diff --git a/LICENSE b/LICENSE index 5a8c78b98b2b..9714b3b1e4d1 100644 --- a/LICENSE +++ b/LICENSE @@ -257,9 +257,8 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware.kryo:kryo:2.21 - http://code.google.com/p/kryo/) - (New BSD License) MinLog (com.esotericsoftware.minlog:minlog:1.2 - http://code.google.com/p/minlog/) - (New BSD License) ReflectASM (com.esotericsoftware.reflectasm:reflectasm:1.07 - http://code.google.com/p/reflectasm/) + (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) + (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f3152cc23222..31bca1658045 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -17,10 +17,10 @@ # mllib.R: Provides methods for MLlib integration -#' @title S4 class that represents a PipelineModel -#' @param model A Java object reference to the backing Scala PipelineModel +#' @title S4 class that represents a generalized linear model +#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper #' @export -setClass("PipelineModel", representation(model = "jobj")) +setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' @title S4 class that represents a NaiveBayesModel #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper @@ -39,21 +39,18 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' Fits a generalized linear model #' -#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' Fits a generalized linear model, similarly to R's glm(). #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data DataFrame for training -#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. -#' @param lambda Regularization parameter -#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) -#' @param standardize Whether to standardize features before training -#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and -#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory -#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an -#' analytical solution to the linear regression problem. The default value is "auto" -#' which means that the solver algorithm is selected automatically. -#' @return a fitted MLlib model +#' @param data DataFrame for training. +#' @param family A description of the error distribution and link function to be used in the model. 
+#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param epsilon Positive convergence tolerance of iterations. +#' @param maxit Integer giving the maximal number of IRLS iterations. +#' @return a fitted generalized linear model #' @rdname glm #' @export #' @examples @@ -64,25 +61,59 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' df <- createDataFrame(sqlContext, iris) #' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian") #' summary(model) -#'} +#' } setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, - standardize = TRUE, solver = "auto") { - family <- match.arg(family) + function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) { + if (is.character(family)) { + family <- get(family, mode = "function", envir = parent.frame()) + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + formula <- paste(deparse(formula), collapse = "") - model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", formula, data@sdf, family, lambda, - alpha, standardize, solver) - return(new("PipelineModel", model = model)) + + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, family$family, family$link, + epsilon, as.integer(maxit)) + return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) -#' Make predictions from a model +#' Get the summary of a generalized linear model #' -#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param object A fitted MLlib model +#' @param object A fitted generalized linear model +#' @return coefficients the model's coefficients, intercept +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#' } +setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) + +#' Make predictions from a generalized linear model +#' +#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict(). 
+#' +#' @param object A fitted generalized linear model #' @param newData DataFrame for testing -#' @return DataFrame containing predicted values +#' @return DataFrame containing predicted labels in a column named "prediction" #' @rdname predict #' @export #' @examples @@ -90,10 +121,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' model <- glm(y ~ x, trainingData) #' predicted <- predict(model, testData) #' showDF(predicted) -#'} -setMethod("predict", signature(object = "PipelineModel"), +#' } +setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { - return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) #' Make predictions from a naive Bayes model @@ -116,54 +147,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Get the summary of a model -#' -#' Returns the summary of a model produced by glm(), similarly to R's summary(). -#' -#' @param object A fitted MLlib model -#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family -#' or a list with 'coefficients' component for binomial family. \cr -#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals -#' of the estimation, the 'coefficients' gives the estimated coefficients and their -#' estimated standard errors, t values and p-values. (It only available when model -#' fitted by normal solver.) \cr -#' For binomial family: the 'coefficients' gives the estimated coefficients. -#' See summary.glm for more information. \cr -#' @rdname summary -#' @export -#' @examples -#' \dontrun{ -#' model <- glm(y ~ x, trainingData) -#' summary(model) -#'} -setMethod("summary", signature(object = "PipelineModel"), - function(object, ...) { - modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) - features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) - coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", object@model) - if (modelName == "LinearRegressionModel") { - devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelDevianceResiduals", object@model) - devianceResiduals <- matrix(devianceResiduals, nrow = 1) - colnames(devianceResiduals) <- c("Min", "Max") - rownames(devianceResiduals) <- rep("", times = 1) - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) - return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) - } else if (modelName == "LogisticRegressionModel") { - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) - } else { - stop(paste("Unsupported model", modelName, sep = " ")) - } - }) - #' Get the summary of a naive Bayes model #' #' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary(). 
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index fdb591756e3f..a9dbd2bdc4cc 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -25,20 +25,21 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) -test_that("glm and predict", { +test_that("formula of glm", { training <- suppressWarnings(createDataFrame(sqlContext, iris)) - test <- select(training, "Sepal_Length") - model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") - prediction <- predict(model, test) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + # dot minus and intercept vs native glm + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) + # feature interaction vs native glm + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -test_that("glm should work with long formula", { + # glm should work with long formula training <- suppressWarnings(createDataFrame(sqlContext, iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length @@ -50,68 +51,30 @@ test_that("glm should work with long formula", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) -test_that("predictions match with native glm", { +test_that("glm and predict", { training <- suppressWarnings(createDataFrame(sqlContext, iris)) + # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - vals <- collect(select(predict(model, training), "prediction")) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("dot minus and intercept vs native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) -test_that("feature interaction vs native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + # poisson family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) -test_that("summary coefficients match with native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) - coefs <- unlist(stats$coefficients) - devianceResiduals <- unlist(stats$devianceResiduals) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - rCoefs <- unlist(rStats$coefficients) - rDevianceResiduals <- c(-0.95096, 0.72918) - - expect_true(all(abs(rCoefs - coefs) < 1e-5)) - expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) -}) - -test_that("summary coefficients match with native glm of family 'binomial'", { - df <- suppressWarnings(createDataFrame(sqlContext, iris)) - training <- filter(df, df$Species != "setosa") - stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, - family = "binomial")) - coefs <- as.vector(stats$coefficients[, 1]) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit")))) - - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) -}) - -test_that("summary works on base GLM models", { - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("kmeans", { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 22eb3ec98467..d747d4f83f24 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1853,7 +1853,7 @@ test_that("approxQuantile() on a DataFrame", { test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table not found", retError), TRUE) + expect_equal(grepl("Table or View not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) }) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 
b5a9d6671f7c..a27aaf2b277f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -123,16 +123,15 @@ public TransportClientFactory( public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. - long preResolveHost = System.nanoTime(); - final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; - logger.info("Spent {} ms to resolve {}", hostResolveTimeMs, address); + // Use unresolved address here to avoid DNS resolution each time we creates a client. + final InetSocketAddress unresolvedAddress = + InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. - ClientPool clientPool = connectionPool.get(address); + ClientPool clientPool = connectionPool.get(unresolvedAddress); if (clientPool == null) { - connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer)); - clientPool = connectionPool.get(address); + connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); + clientPool = connectionPool.get(unresolvedAddress); } int clientIndex = rand.nextInt(numConnectionsPerPeer); @@ -149,25 +148,35 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); + logger.trace("Returning cached connection to {}: {}", + cachedClient.getSocketAddress(), cachedClient); return cachedClient; } } // If we reach here, we don't have an existing connection open. Let's create a new one. // Multiple threads might race here to create new connections. Keep only one of them active. 
+ final long preResolveHost = System.nanoTime(); + final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); + final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; + if (hostResolveTimeMs > 2000) { + logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } else { + logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } + synchronized (clientPool.locks[clientIndex]) { cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null) { if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); + logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient; } else { - logger.info("Found inactive connection to {}, creating a new one.", address); + logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } - clientPool.clients[clientIndex] = createClient(address); + clientPool.clients[clientIndex] = createClient(resolvedAddress); return clientPool.clients[clientIndex]; } } diff --git a/core/pom.xml b/core/pom.xml index 4c7e3a36620a..7349ad35b959 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -192,7 +192,6 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.10 com.sun.jersey diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java index 23673d3e3d7a..3fcb52f61583 100644 --- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java +++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java @@ -34,7 +34,7 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2); - public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1); + public static final StorageLevel OFF_HEAP = create(true, true, true, false, 1); /** * Create a new StorageLevel object. diff --git a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java similarity index 94% rename from core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java rename to core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java index 27b6f0d4a388..8783b5f56eba 100644 --- a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java +++ b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java @@ -20,20 +20,17 @@ import java.io.InputStream; import java.util.zip.Checksum; -import net.jpountz.lz4.LZ4BlockOutputStream; import net.jpountz.lz4.LZ4Exception; import net.jpountz.lz4.LZ4Factory; import net.jpountz.lz4.LZ4FastDecompressor; import net.jpountz.util.SafeUtils; -import net.jpountz.xxhash.StreamingXXHash32; -import net.jpountz.xxhash.XXHash32; import net.jpountz.xxhash.XXHashFactory; /** * {@link InputStream} implementation to decode data written with - * {@link LZ4BlockOutputStream}. This class is not thread-safe and does not + * {@link net.jpountz.lz4.LZ4BlockOutputStream}. This class is not thread-safe and does not * support {@link #mark(int)}/{@link #reset()}. 
- * @see LZ4BlockOutputStream + * @see net.jpountz.lz4.LZ4BlockOutputStream * * This is based on net.jpountz.lz4.LZ4BlockInputStream * @@ -90,12 +87,13 @@ public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Che } /** - * Create a new instance using {@link XXHash32} for checksuming. + * Create a new instance using {@link net.jpountz.xxhash.XXHash32} for checksuming. * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) - * @see StreamingXXHash32#asChecksum() + * @see net.jpountz.xxhash.StreamingXXHash32#asChecksum() */ public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { - this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + this(in, decompressor, + XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); } /** diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 81ee7ab58ab5..3c2980e442ab 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -215,8 +215,6 @@ private void writeSortedFile(boolean isLastFile) throws IOException { } } - inMemSorter.reset(); - if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -255,6 +253,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { writeSortedFile(false); final long spillSize = freeMemory(); + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. 
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index fe79ff0e3052..76b0e6a304ac 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -51,9 +51,12 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { */ private int pos = 0; + private int initialSize; + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { this.consumer = consumer; assert (initialSize > 0); + this.initialSize = initialSize; this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } @@ -70,6 +73,10 @@ public int numRecords() { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 32958be7a7fd..6807710f9fef 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -716,7 +716,8 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); offset += vlen; - Platform.putLong(base, offset, 0); + // put this value at the beginning of the list + Platform.putLong(base, offset, isDefined ? longArray.get(pos * 2) : 0); // --- Update bookkeeping data structures ---------------------------------------------------- offset = currentPage.getBaseOffset(); @@ -724,17 +725,12 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff pageCursor += recordLength; final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); + longArray.set(pos * 2, storedKeyAddress); + updateAddressesAndSizes(storedKeyAddress); numValues++; - if (isDefined) { - // put this pair at the end of chain - while (nextValue()) { /* do nothing */ } - Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress); - nextValue(); // point to new added value - } else { + if (!isDefined) { numKeys++; - longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); - updateAddressesAndSizes(storedKeyAddress); isDefined = true; if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ded8f0472b27..ef79b4908347 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -200,14 +200,17 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); - - inMemSorter.reset(); } final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space 
isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 145c3a195064..01eae0e8dc14 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -84,6 +84,8 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { */ private int pos = 0; + private long initialSize; + public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, @@ -102,6 +104,7 @@ public UnsafeInMemorySorter( LongArray array) { this.consumer = consumer; this.memoryManager = memoryManager; + this.initialSize = array.size(); if (recordComparator != null) { this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); @@ -123,6 +126,10 @@ public void free() { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 48f86d1536c9..47dd9162a1bf 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -106,21 +106,22 @@ pre { line-height: 18px; padding: 6px; margin: 0; + word-break: break-word; border-radius: 3px; } .stage-details { - max-height: 100px; overflow-y: auto; margin: 0; + display: block; transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stage-details.collapsed { - max-height: 0; padding-top: 0; padding-bottom: 0; border: none; + display: none; } .description-input { @@ -143,14 +144,15 @@ pre { max-height: 300px; overflow-y: auto; margin: 0; + display: block; transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stacktrace-details.collapsed { - max-height: 0; padding-top: 0; padding-bottom: 0; border: none; + display: none; } span.expand-additional-metrics, span.expand-dag-viz { diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 8fc657c5ebe4..76692ccec815 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -278,9 +278,9 @@ private object ContextCleaner { * Listener class used for testing when any item has been cleaned by the Cleaner class. 
*/ private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) - def broadcastCleaned(broadcastId: Long) - def accumCleaned(accId: Long) - def checkpointCleaned(rddId: Long) + def rddCleaned(rddId: Int): Unit + def shuffleCleaned(shuffleId: Int): Unit + def broadcastCleaned(broadcastId: Long): Unit + def accumCleaned(accId: Long): Unit + def checkpointCleaned(rddId: Long): Unit } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index ce11772a6d8d..339266a5d48b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] { /** * Cancels the execution of this action. */ - def cancel() + def cancel(): Unit /** * Blocks until this action completes. @@ -65,7 +65,7 @@ trait FutureAction[T] extends Future[T] { * When this action is completed, either through an exception, or a value, applies the provided * function. */ - def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) + def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit /** * Returns whether the action has already been completed with a value or an exception. diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 9fad1f6786ad..982b6d6b6173 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -25,6 +25,7 @@ import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} +import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.util.thread.QueuedThreadPool @@ -155,6 +156,12 @@ private[spark] class HttpServer( throw new ServerStateException("Server is already stopped") } else { server.stop() + // Stop the ThreadPool if it supports stop() method (through LifeCycle). + // It is needed because stopping the Server won't stop the ThreadPool it uses. + val threadPool = server.getThreadPool + if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { + threadPool.asInstanceOf[LifeCycle].stop + } port = -1 server = null } diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 7aa9057858a0..0dd4ec656f5a 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -187,7 +187,7 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. 
*/ - def create(sc: SparkContext): Seq[Accumulator[_]] = { + def createAll(sc: SparkContext): Seq[Accumulator[_]] = { val accums = createAll() accums.foreach { accum => Accumulators.register(accum) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ec5cedf258f..e41088f7c8f6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Set a local property that affects jobs submitted from this thread, such as the - * Spark fair scheduler pool. + * Set a local property that affects jobs submitted from this thread, such as the Spark fair + * scheduler pool. User-defined properties may also be set here. These properties are propagated + * through to worker tasks and can be accessed there via + * [[org.apache.spark.TaskContext#getLocalProperty]]. */ def setLocalProperty(key: String, value: String) { if (value == null) { @@ -721,7 +723,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (safeEnd - safeStart) / step + 1 } } - parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -760,7 +762,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli ret } } - }) + } } /** Distribute a local Scala collection to form an RDD. @@ -2395,9 +2397,8 @@ object SparkContext extends Logging { } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { val clazz = @@ -2405,9 +2406,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) (backend, scheduler) @@ -2419,9 +2419,8 @@ object SparkContext extends Logging { cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { @@ -2430,9 +2429,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 700e2cb3f91b..3d11db7461c0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -101,14 +101,13 @@ class SparkEnv ( // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the // current working dir in executor which we do not need to delete. 
driverTmpDirToDelete match { - case Some(path) => { + case Some(path) => try { Utils.deleteRecursively(new File(path)) } catch { case e: Exception => logWarning(s"Exception while deleting Spark temp dir: $path", e) } - } case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } @@ -314,7 +313,8 @@ object SparkEnv extends Logging { UnifiedMemoryManager(conf, numUsableCores) } - val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) + val blockTransferService = + new NettyBlockTransferService(conf, securityManager, hostname, numUsableCores) val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index bfcacbf229b0..757c1b5116f3 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.Serializable +import java.util.Properties import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -64,7 +65,7 @@ object TaskContext { * An empty task context that does not represent an actual task. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, null) + new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) } } @@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable { */ def taskAttemptId(): Long + /** + * Get a local property set upstream in the driver, or null if it is missing. See also + * [[org.apache.spark.SparkContext.setLocalProperty]]. + */ + def getLocalProperty(key: String): String + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c9354b3e5574..fa0b2d3d2829 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.util.Properties + import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics @@ -32,6 +34,7 @@ private[spark] class TaskContextImpl( override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, @transient private val metricsSystem: MetricsSystem, initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll()) extends TaskContext @@ -118,6 +121,8 @@ private[spark] class TaskContextImpl( override def isInterrupted(): Boolean = interrupted + override def getLocalProperty(key: String): String = localProperties.getProperty(key) + override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6f6730690f85..6259bead3ea8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -134,11 +134,10 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable - case array: Array[Any] => { + case array: Array[Any] => val 
arrayWriteable = new ArrayWritable(classOf[Writable]) arrayWriteable.set(array.map(convertToWritable(_))) arrayWriteable - } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4bca16a23443..ab5b6c8380e8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -470,7 +470,7 @@ private[spark] object PythonRDD extends Logging { objs.append(obj) } } catch { - case eof: EOFException => {} + case eof: EOFException => // No-op } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } finally { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 41ac308808c8..cda9d38c6a82 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -152,10 +152,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesRead = f() Some(() => f() - baselineBytesRead) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) None - } } } @@ -174,10 +173,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesWritten = f() Some(() => f() - baselineBytesWritten) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) None - } } } @@ -315,7 +313,7 @@ class SparkHadoopUtil extends Logging { */ def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { text match { - case HADOOP_CONF_PATTERN(matched) => { + case HADOOP_CONF_PATTERN(matched) => logDebug(text + " matched " + HADOOP_CONF_PATTERN) val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } val eval = Option[String](hadoopConf.get(key)) @@ -330,11 +328,9 @@ class SparkHadoopUtil extends Logging { // Continue to substitute more variables. substituteHadoopVariables(eval.get, hadoopConf) } - } - case _ => { + case _ => logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) text - } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala index e584952a9ad8..94506a0cbb27 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala @@ -33,7 +33,8 @@ private[spark] trait AppClientListener { /** An application death is an unrecoverable failure condition. 
*/ def dead(reason: String): Unit - def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) + def executorAdded( + fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 70f21fbe0de8..52e2854961ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -32,8 +32,8 @@ trait LeaderElectionAgent { @DeveloperApi trait LeaderElectable { - def electedLeader() - def revokedLeadership() + def electedLeader(): Unit + def revokedLeadership(): Unit } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 01901bbf85d7..b443e8f0519f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -217,7 +217,7 @@ private[deploy] class Master( } override def receive: PartialFunction[Any, Unit] = { - case ElectedLeader => { + case ElectedLeader => val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE @@ -233,16 +233,14 @@ private[deploy] class Master( } }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } - } case CompleteRecovery => completeRecovery() - case RevokedLeadership => { + case RevokedLeadership => logError("Leadership has been revoked -- master shutting down.") System.exit(0) - } - case RegisterApplication(description, driver) => { + case RegisterApplication(description, driver) => // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response @@ -255,12 +253,11 @@ private[deploy] class Master( driver.send(RegisteredApplication(app.id, self)) schedule() } - } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) execOption match { - case Some(exec) => { + case Some(exec) => val appInfo = idToApp(appId) val oldState = exec.state exec.state = state @@ -298,22 +295,19 @@ private[deploy] class Master( } } } - } case None => logWarning(s"Got status update for unknown executor $appId/$execId") } - } - case DriverStateChanged(driverId, state, exception) => { + case DriverStateChanged(driverId, state, exception) => state match { case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => removeDriver(driverId, state, exception) case _ => throw new Exception(s"Received unexpected state update for driver $driverId: $state") } - } - case Heartbeat(workerId, worker) => { + case Heartbeat(workerId, worker) => idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -327,9 +321,8 @@ private[deploy] class Master( " This worker was never registered, so ignoring the heartbeat.") } } - } - case MasterChangeAcknowledged(appId) => { + case 
MasterChangeAcknowledged(appId) => idToApp.get(appId) match { case Some(app) => logInfo("Application has been re-registered: " + appId) @@ -339,9 +332,8 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } - case WorkerSchedulerStateResponse(workerId, executors, driverIds) => { + case WorkerSchedulerStateResponse(workerId, executors, driverIds) => idToWorker.get(workerId) match { case Some(worker) => logInfo("Worker has been re-registered: " + workerId) @@ -367,7 +359,6 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } case WorkerLatestState(workerId, executors, driverIds) => idToWorker.get(workerId) match { @@ -397,9 +388,8 @@ private[deploy] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case CheckForWorkerTimeOut => { + case CheckForWorkerTimeOut => timeOutDeadWorkers() - } case AttachCompletedRebuildUI(appId) => // An asyncRebuildSparkUI has completed, so need to attach to master webUi @@ -408,7 +398,7 @@ private[deploy] class Master( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => { + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -430,9 +420,8 @@ private[deploy] class Master( + workerAddress)) } } - } - case RequestSubmitDriver(description) => { + case RequestSubmitDriver(description) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only accept driver submissions in ALIVE state." @@ -451,9 +440,8 @@ private[deploy] class Master( context.reply(SubmitDriverResponse(self, true, Some(driver.id), s"Driver successfully submitted as ${driver.id}")) } - } - case RequestKillDriver(driverId) => { + case RequestKillDriver(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + s"Can only kill drivers in ALIVE state." @@ -484,9 +472,8 @@ private[deploy] class Master( context.reply(KillDriverResponse(self, driverId, success = false, msg)) } } - } - case RequestDriverStatus(driverId) => { + case RequestDriverStatus(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only request driver status in ALIVE state." 
@@ -501,18 +488,15 @@ private[deploy] class Master( context.reply(DriverStatusResponse(found = false, None, None, None, None)) } } - } - case RequestMasterState => { + case RequestMasterState => context.reply(MasterStateResponse( address.host, address.port, restServerBoundPort, workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state)) - } - case BoundPortsRequest => { + case BoundPortsRequest => context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) - } case RequestExecutors(appId, requestedTotal) => context.reply(handleRequestExecutors(appId, requestedTotal)) @@ -859,10 +843,10 @@ private[deploy] class Master( addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) - completedApps.take(toRemove).foreach( a => { + completedApps.take(toRemove).foreach { a => Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) - }) + } completedApps.trimStart(toRemove) } completedApps += app // Remember it in our history diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 9cd7458ba090..585e0839d0fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -78,7 +78,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { case ("--help") :: tail => printUsageAndExit(0) - case Nil => {} + case Nil => // No-op case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index dddf2be57ee4..b30bc821b732 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -40,12 +40,12 @@ abstract class PersistenceEngine { * Defines how the object is serialized and persisted. Implementation will * depend on the store used. */ - def persist(name: String, obj: Object) + def persist(name: String, obj: Object): Unit /** * Defines how the object referred by its name is removed from the store. */ - def unpersist(name: String) + def unpersist(name: String): Unit /** * Gives all objects, matching a prefix. 
This defines how objects are diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 79f77212fefb..af850e4871e5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -70,11 +70,10 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer try { Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(WORKING_DIR + "/" + filename) None - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index b97805a28bdc..11e13441eeba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -76,14 +76,13 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--help") :: tail => printUsageAndExit(0) - case Nil => { + case Nil => if (masterUrl == null) { // scalastyle:off println System.err.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - } case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 9c6bc5c62f25..aad2e91b2555 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -218,7 +218,7 @@ private[deploy] class DriverRunner( } private[deploy] trait Sleeper { - def sleep(seconds: Int) + def sleep(seconds: Int): Unit } // Needed because ProcessBuilder is a final class and cannot be mocked diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index f9c92c3bb9f8..06066248ea5d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -179,16 +179,14 @@ private[deploy] class ExecutorRunner( val message = "Command exited with code " + exitCode worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { - case interrupted: InterruptedException => { + case interrupted: InterruptedException => logInfo("Runner thread for executor " + fullId + " interrupted") state = ExecutorState.KILLED killProcess(None) - } - case e: Exception => { + case e: Exception => logError("Error running executor", e) state = ExecutorState.FAILED killProcess(Some(e.toString)) - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1b7637a39ca7..449beb081117 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -480,7 +480,7 @@ private[deploy] class Worker( memoryUsed += memory_ sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { - 
case e: Exception => { + case e: Exception => logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() @@ -488,7 +488,6 @@ private[deploy] class Worker( } sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(e.toString), None)) - } } } @@ -509,7 +508,7 @@ private[deploy] class Worker( } } - case LaunchDriver(driverId, driverDesc) => { + case LaunchDriver(driverId, driverDesc) => logInfo(s"Asked to launch driver $driverId") val driver = new DriverRunner( conf, @@ -525,9 +524,8 @@ private[deploy] class Worker( coresUsed += driverDesc.cores memoryUsed += driverDesc.mem - } - case KillDriver(driverId) => { + case KillDriver(driverId) => logInfo(s"Asked to kill driver $driverId") drivers.get(driverId) match { case Some(runner) => @@ -535,11 +533,9 @@ private[deploy] class Worker( case None => logError(s"Asked to kill unknown driver $driverId") } - } - case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => handleDriverStateChanged(driverStateChanged) - } case ReregisterWithMaster => reregisterWithMaster() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 391eb4119092..777020d4d5c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -165,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } // scalastyle:on classforname } catch { - case e: Exception => { + case e: Exception => totalMb = 2*1024 // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") // scalastyle:on println - } } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index d4ed5845e747..71b4ad160d67 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,10 +62,9 @@ private[spark] class CoarseGrainedExecutorBackend( // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => // Always receive `true`. 
Just ignore it - case Failure(e) => { + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) System.exit(1) - } }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 09c57335650c..9f94fdef24eb 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ @@ -206,9 +207,16 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + val (taskFiles, taskJars, taskProps, taskBytes) = + Task.deserializeWithDependencies(serializedTask) + + // Must be set before updateDependencies() is called, in case fetching dependencies + // requires access to properties contained within (e.g. for access control). + Executor.taskDeserializationProps.set(taskProps) + updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.localProperties = taskProps task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, @@ -321,7 +329,7 @@ private[spark] class Executor( logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - case cDE: CommitDeniedException => + case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -506,3 +514,10 @@ private[spark] class Executor( heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } + +private[spark] object Executor { + // This is reserved for internal use by components that need to read task properties before a + // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be + // used instead. + val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties] +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index e07cb31cbe4b..7153323d01a0 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState * A pluggable interface used by the Executor to send updates to the cluster scheduler. 
*/ private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) + def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit } diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 6d30d3c76a9f..83e11c5e236d 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -81,35 +81,9 @@ class InputMetrics private ( */ def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) - // Once incBytesRead & intRecordsRead is ready to be removed from the public API - // we can remove the internal versions and make the previous public API private. - // This has been done to suppress warnings when building. - @deprecated("incrementing input metrics is for internal use only", "2.0.0") - def incBytesRead(v: Long): Unit = _bytesRead.add(v) - private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v) - @deprecated("incrementing input metrics is for internal use only", "2.0.0") - def incRecordsRead(v: Long): Unit = _recordsRead.add(v) - private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v) + private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) - private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = - _readMethod.setValue(v.toString) + private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) } - -/** - * Deprecated methods to preserve case class matching behavior before Spark 2.0. - */ -object InputMetrics { - - @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") - def apply(readMethod: DataReadMethod.Value): InputMetrics = { - val im = new InputMetrics - im.setReadMethod(readMethod) - im - } - - @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") - def unapply(input: InputMetrics): Option[DataReadMethod.Value] = { - Some(input.readMethod) - } -} diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index 0b37d559c746..93f953846fe2 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -51,18 +51,6 @@ class OutputMetrics private ( TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) } - /** - * Create a new [[OutputMetrics]] that is not associated with any particular task. - * - * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be - * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]] - * we can remove this constructor as well. - */ - private[executor] def this() { - this(InternalAccumulator.createOutputAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) - } - /** * Total number of bytes written. */ @@ -84,21 +72,3 @@ class OutputMetrics private ( _writeMethod.setValue(v.toString) } - -/** - * Deprecated methods to preserve case class matching behavior before Spark 2.0. 
- */ -object OutputMetrics { - - @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") - def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = { - val om = new OutputMetrics - om.setWriteMethod(writeMethod) - om - } - - @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") - def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = { - Some(output.writeMethod) - } -} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 50bb645d974a..71a24770b50a 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -116,4 +116,25 @@ class ShuffleReadMetrics private ( private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) + /** + * Resets the value of the current metrics (`this`) and and merges all the independent + * [[ShuffleReadMetrics]] into `this`. + */ + private[spark] def setMergeValues(metrics: Seq[ShuffleReadMetrics]): Unit = { + _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero) + _localBlocksFetched.setValue(_localBlocksFetched.zero) + _remoteBytesRead.setValue(_remoteBytesRead.zero) + _localBytesRead.setValue(_localBytesRead.zero) + _fetchWaitTime.setValue(_fetchWaitTime.zero) + _recordsRead.setValue(_recordsRead.zero) + metrics.foreach { metric => + _remoteBlocksFetched.add(metric.remoteBlocksFetched) + _localBlocksFetched.add(metric.localBlocksFetched) + _remoteBytesRead.add(metric.remoteBytesRead) + _localBytesRead.add(metric.localBytesRead) + _fetchWaitTime.add(metric.fetchWaitTime) + _recordsRead.add(metric.recordsRead) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 02219a84abfd..bda2a91d9d2c 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -139,16 +139,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue - @deprecated("use updatedBlockStatuses instead", "2.0.0") - def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { - if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None - } - - @deprecated("setting updated blocks is not allowed", "2.0.0") - def updatedBlocks_=(blocks: Option[Seq[(BlockId, BlockStatus)]]): Unit = { - blocks.foreach(setUpdatedBlockStatuses) - } - // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) @@ -225,11 +215,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def outputMetrics: Option[OutputMetrics] = _outputMetrics - @deprecated("setting OutputMetrics is for internal use only", "2.0.0") - def outputMetrics_=(om: Option[OutputMetrics]): Unit = { - _outputMetrics = om - } - /** * Get or create a new [[OutputMetrics]] associated with this task. 
*/ @@ -285,12 +270,7 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { val metrics = new ShuffleReadMetrics(initialAccumsMap) - metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum) - metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum) - metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum) - metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum) - metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum) - metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum) + metrics.setMergeValues(tempShuffleReadMetrics) _shuffleReadMetrics = Some(metrics) } } @@ -306,11 +286,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics - @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") - def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { - _shuffleWriteMetrics = swm - } - /** * Get or create a new [[ShuffleWriteMetrics]] associated with this task. */ diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4da1017d282e..0fed991049dd 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -196,10 +196,9 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => { + case e: Exception => logError("Sink class " + classPath + " cannot be instantiated") throw e - } } } } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index e43e3a2de256..09ce012e4e69 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -36,7 +36,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch * local blocks or put local blocks. */ - def init(blockDataManager: BlockDataManager) + def init(blockDataManager: BlockDataManager): Unit /** * Tear down the transfer service. diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 5f3d4532dd86..33a321960774 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -39,7 +39,11 @@ import org.apache.spark.util.Utils /** * A BlockTransferService that uses Netty to fetch a set of blocks at at time. */ -class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int) +private[spark] class NettyBlockTransferService( + conf: SparkConf, + securityManager: SecurityManager, + override val hostName: String, + numCores: Int) extends BlockTransferService { // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
@@ -65,13 +69,13 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) server = createServer(serverBootstrap.toList) appId = conf.getAppId - logInfo("Server created on " + server.getPort) + logInfo(s"Server created on ${hostName}:${server.getPort}") } /** Creates and binds the TransportServer, possibly trying multiple ports. */ private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = { def startService(port: Int): (TransportServer, Int) = { - val server = transportContext.createServer(port, bootstraps.asJava) + val server = transportContext.createServer(hostName, port, bootstraps.asJava) (server, server.getPort) } @@ -109,8 +113,6 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage } } - override def hostName: String = Utils.localHostName() - override def port: Int = server.getPort override def uploadBlock( diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index c562c70aba4f..ab6aba6fc7d6 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -32,12 +32,11 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v */ override def equals(that: Any): Boolean = that match { - case that: BoundedDouble => { + case that: BoundedDouble => this.mean == that.mean && this.confidence == that.confidence && this.low == that.low && this.high == that.high - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 8358244987a6..63d1d1767a8c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -35,9 +35,9 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.length).map(i => { + (0 until blockIds.length).map { i => new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 5e9230e7337c..368916a39e64 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -166,8 +166,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x) += 1} - case _ => {} + case Some(x: Int) => counters(x) += 1 + case _ => // No-Op } } Iterator(counters) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 08db96edd69b..35d190b464ff 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,15 +213,13 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - // TODO: there is a lot of duplicate code between this and NewHadoopRDD and 
SqlNewHadoopRDD - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() + case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) + case _ => InputFileNameHolder.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -261,7 +259,7 @@ class HadoopRDD[K, V]( finished = true } if (!finished) { - inputMetrics.incRecordsReadInternal(1) + inputMetrics.incRecordsRead(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -271,7 +269,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic @@ -293,7 +291,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) @@ -424,7 +422,7 @@ private[spark] object HadoopRDD extends Logging { private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { val out = ListBuffer[String]() - infos.foreach { loc => { + infos.foreach { loc => val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { @@ -436,7 +434,7 @@ private[spark] object HadoopRDD extends Logging { out += new HostTaskLocation(locationStr).toString } } - }} + } out.seq } } diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala similarity index 86% rename from core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala rename to core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala index 3f15fff79366..108e9d255819 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -20,10 +20,10 @@ package org.apache.spark.rdd import org.apache.spark.unsafe.types.UTF8String /** - * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. - * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD and InputFileName function in Spark SQL. */ -private[spark] object SqlNewHadoopRDDState { +private[spark] object InputFileNameHolder { /** * The thread variable for the name of the current file being read. This is used by * the InputFileName function in Spark SQL. 
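The rename above (SqlNewHadoopRDDState to InputFileNameHolder) keeps the same mechanism: a per-thread holder that HadoopRDD, FileScanRDD and the InputFileName SQL function consult for the file backing the current record. Below is a minimal sketch of that thread-local pattern, assuming a plain String value and a hypothetical withInputFile helper added only for illustration; the real object stores a UTF8String and exposes only set/unset/get.

object InputFileNameHolderSketch {
  // Thread-local name of the input file currently being read by this task's thread.
  private val inputFileName: ThreadLocal[String] = new ThreadLocal[String] {
    override def initialValue(): String = ""
  }

  def getInputFileName(): String = inputFileName.get()
  def setInputFileName(file: String): Unit = inputFileName.set(file)
  def unsetInputFileName(): Unit = inputFileName.set("")

  // Usage mirrors HadoopRDD's record reader: set the name before reading a
  // FileSplit and clear it when the reader is closed, so callers see the value
  // only while records from that split are being produced.
  def withInputFile[T](path: String)(read: => T): T = {
    setInputFileName(path)
    try read finally unsetInputFileName()
  }
}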
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index fb9606ae388d..3ccd616cbfd5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -189,7 +189,7 @@ class NewHadoopRDD[K, V]( } havePair = false if (!finished) { - inputMetrics.incRecordsReadInternal(1) + inputMetrics.incRecordsRead(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -220,7 +220,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 363004e587f2..a5992022d083 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) val rddToFilter: RDD[P] = self.partitioner match { - case Some(rp: RangePartitioner[K, V]) => { + case Some(rp: RangePartitioner[K, V]) => val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { case (l, u) => Math.min(l, u) to Math.max(l, u) } PartitionPruningRDD.create(self, partitionIndicies.contains) - } case _ => self } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 296179b75bc4..085829af6eee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1111,9 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } - } { - writer.close(hadoopContext) - } + }(finallyBlock = writer.close(hadoopContext)) committer.commitTask(hadoopContext) outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => om.setBytesWritten(callback()) @@ -1200,9 +1198,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } - } { - writer.close() - } + }(finallyBlock = writer.close()) writer.commit() outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => om.setBytesWritten(callback()) diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 582fa93afe34..bb84e4af15b1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -121,14 +121,14 @@ private object ParallelCollectionRDD { // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { - (0 until numSlices).iterator.map(i => { + 
(0 until numSlices).iterator.map { i => val start = ((i * length) / numSlices).toInt val end = (((i + 1) * length) / numSlices).toInt (start, end) - }) + } } seq match { - case r: Range => { + case r: Range => positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == numSlices - 1) { @@ -138,8 +138,7 @@ private object ParallelCollectionRDD { new Range(r.start + start * r.step, r.start + end * r.step, r.step) } }).toSeq.asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { + case nr: NumericRange[_] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) var r = nr @@ -149,14 +148,12 @@ private object ParallelCollectionRDD { r = r.drop(sliceSize) } slices - } - case _ => { + case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc positions(array.length, numSlices).map({ case (start, end) => array.slice(start, end).toSeq }).toSeq - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 9e3880714a79..0abba15bec9f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -68,9 +68,9 @@ class PartitionerAwareUnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val numPartitions = partitioner.get.numPartitions - (0 until numPartitions).map(index => { + (0 until numPartitions).map { index => new PartitionerAwareUnionRDDPartition(rdds, index) - }).toArray + }.toArray } // Get the location where most of the partitions of parent RDDs are located @@ -78,11 +78,10 @@ class PartitionerAwareUnionRDD[T: ClassTag]( logDebug("Finding preferred location for " + this + ", partition " + s.index) val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents val locations = rdds.zip(parentPartitions).flatMap { - case (rdd, part) => { + case (rdd, part) => val parentLocations = currPrefLocs(rdd, part) logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations) parentLocations - } } val location = if (locations.isEmpty) { None diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 032939b49a70..36ff3bcaaec6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -333,10 +333,10 @@ abstract class RDD[T: ClassTag]( case Left(blockResult) => if (readCachedBlock) { val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesReadInternal(blockResult.bytes) + existingMetrics.incBytesRead(blockResult.bytes) new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) { override def next(): T = { - existingMetrics.incRecordsReadInternal(1) + existingMetrics.incRecordsRead(1) delegate.next() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5cdc91316b69..c27aad268d32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -950,13 +950,6 @@ class DAGScheduler( // First figure out the indexes of partition ids to compute. 
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() - // Create internal accumulators if the stage has no accumulators initialized. - // Reset internal accumulators only if this stage is not partially submitted - // Otherwise, we may override existing accumulator values from some tasks - if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) { - stage.resetInternalAccumulators() - } - // Use the scheduling pool, job group, description, etc. from an ActiveJob associated // with this Stage val properties = jobIdToActiveJob(jobId).properties @@ -1036,7 +1029,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.internalAccumulators) + taskBinary, part, locs, stage.latestInfo.internalAccumulators, properties) } case stage: ResultStage => @@ -1046,7 +1039,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, stage.internalAccumulators) + taskBinary, part, locs, id, properties, stage.latestInfo.internalAccumulators) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0640f2605143..a6b032cc0084 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // Since we are not doing canonicalization of path, this can be wrong : like relative vs // absolute path .. which is fine, this is best case effort to remove duplicates - right ? override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { + case that: InputFormatInfo => // not checking config - that should be fine, right ? this.inputFormatClazz == that.inputFormatClazz && this.path == that.path - } case _ => false } @@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } } catch { - case e: ClassNotFoundException => { + case e: ClassNotFoundException => throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index 50c2b9acd609..e0f7c8f02132 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -23,6 +23,6 @@ package org.apache.spark.scheduler * job fails (and no further taskSucceeded events will happen). 
*/ private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) + def taskSucceeded(index: Int, result: Any): Unit + def jobFailed(exception: Exception): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index cd2736e1960c..db6276f75d78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer +import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast @@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param localProperties copy of thread-local properties set by the user on the driver side. * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. @@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, + localProperties: Properties, _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) - extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 5baebe8c1ff8..100ed76ecb6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -34,9 +34,9 @@ import org.apache.spark.util.Utils private[spark] trait SchedulableBuilder { def rootPool: Pool - def buildPools() + def buildPools(): Unit - def addTaskSetManager(manager: Schedulable, properties: Properties) + def addTaskSetManager(manager: Schedulable, properties: Properties): Unit } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index e30964a01bda..b7cab7013ef6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.language.existentials @@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. 
*/ private[spark] class ShuffleMapTask( stageId: Int, @@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - _initialAccums: Seq[Accumulator[_]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums) + _initialAccums: Seq[Accumulator[_]], + localProperties: Properties) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 6e9337bb9063..bc1431835e25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -49,14 +49,13 @@ class SplitInfo( // So unless there is identity equality between underlyingSplits, it will always fail even if it // is pointing to same block. override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { + case that: SplitInfo => this.hostLocation == that.hostLocation && this.inputFormatClazz == that.inputFormatClazz && this.path == that.path && this.length == that.length && // other split specific checks (like start for FileSplit) this.underlyingSplit == that.underlyingSplit - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index a40b700cdd35..b6d4e39fe532 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -75,22 +75,6 @@ private[scheduler] abstract class Stage( val name: String = callSite.shortForm val details: String = callSite.longForm - private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty - - /** Internal accumulators shared across all tasks in this stage. */ - def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators - - /** - * Re-initialize the internal accumulators associated with this stage. - * - * This is called every time the stage is submitted, *except* when a subset of tasks - * belonging to this stage has already finished. Otherwise, reinitializing the internal - * accumulators here again will override partial values from the finished tasks. - */ - def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) - } - /** * Pointer to the [StageInfo] object for the most recent attempt. 
This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this @@ -127,7 +111,8 @@ private[scheduler] abstract class Stage( numPartitionsToCompute: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { _latestInfo = StageInfo.fromStage( - this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) + this, nextAttemptId, Some(numPartitionsToCompute), + InternalAccumulator.createAll(rdd.sparkContext), taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 24796c14300b..0fd58c41cdce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashMap +import org.apache.spark.Accumulator import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -35,6 +36,7 @@ class StageInfo( val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], val details: String, + val internalAccumulators: Seq[Accumulator[_]] = Seq.empty, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None @@ -42,7 +44,11 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None - /** Terminal values of accumulables updated during this stage. */ + + /** + * Terminal values of accumulables updated during this stage, including all the user-defined + * accumulators. + */ val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { @@ -75,6 +81,7 @@ private[spark] object StageInfo { stage: Stage, attemptId: Int, numTasks: Option[Int] = None, + internalAccumulators: Seq[Accumulator[_]] = Seq.empty, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) @@ -87,6 +94,7 @@ private[spark] object StageInfo { rddInfos, stage.parents.map(_.id), stage.details, + internalAccumulators, taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 46c64f61de5f..1ff9d7795f42 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.HashMap @@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * @param initialAccumulators initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. 
*/ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - val initialAccumulators: Seq[Accumulator[_]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]], + @transient var localProperties: Properties) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -71,6 +74,7 @@ private[spark] abstract class Task[T]( taskAttemptId, attemptNumber, taskMemoryManager, + localProperties, metricsSystem, initialAccumulators) TaskContext.setTaskContext(context) @@ -80,10 +84,16 @@ private[spark] abstract class Task[T]( } try { runTask(context) - } catch { case e: Throwable => - // Catch all errors; run task failure callbacks, and rethrow the exception. - context.markTaskFailed(e) - throw e + } catch { + case e: Throwable => + // Catch all errors; run task failure callbacks, and rethrow the exception. + try { + context.markTaskFailed(e) + } catch { + case t: Throwable => + e.addSuppressed(t) + } + throw e } finally { // Call the task completion callbacks. context.markTaskCompleted() @@ -206,6 +216,11 @@ private[spark] object Task { dataOut.writeLong(timestamp) } + // Write the task properties separately so it is available before full task deserialization. + val propBytes = Utils.serialize(task.localProperties) + dataOut.writeInt(propBytes.length) + dataOut.write(propBytes) + // Write the task itself and finish dataOut.flush() val taskBytes = serializer.serialize(task) @@ -221,7 +236,7 @@ private[spark] object Task { * @return (taskFiles, taskJars, taskBytes) */ def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { val in = new ByteBufferInputStream(serializedTask) val dataIn = new DataInputStream(in) @@ -240,8 +255,13 @@ private[spark] object Task { taskJars(dataIn.readUTF()) = dataIn.readLong() } + val propLength = dataIn.readInt() + val propBytes = new Array[Byte](propLength) + dataIn.readFully(propBytes, 0, propLength) + val taskProps = Utils.deserialize[Properties](propBytes) + // Create a sub-buffer for the rest of the data, which is the serialized Task object val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) + (taskFiles, taskJars, taskProps, subBuffer) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 873f1b56bd18..ae7ef46abbf3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -133,7 +133,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul // if we can't deserialize the reason. 
logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + case ex: Exception => // No-op } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 8477a66b394f..647d44a0f068 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -51,7 +51,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int, interruptThread: Boolean) + def cancelTasks(stageId: Int, interruptThread: Boolean): Unit // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index daed2ff50e15..c3159188d9f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -571,6 +571,11 @@ private[spark] class TaskSchedulerImpl( return } while (!backend.isReady) { + // Might take a while for backend to be ready if it is waiting on resources. + if (sc.stopped.get) { + // For example: the master removes the application for some reason + throw new IllegalStateException("Spark context stopped while waiting for backend") + } synchronized { this.wait(100) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 15d3515a02b3..6e08cdd87a8d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -188,20 +188,18 @@ private[spark] class TaskSetManager( loc match { case e: ExecutorCacheTaskLocation => pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index - case e: HDFSCacheTaskLocation => { + case e: HDFSCacheTaskLocation => val exe = sched.getExecutorsAliveOnHost(loc.host) exe match { - case Some(set) => { + case Some(set) => for (e <- set) { pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } logInfo(s"Pending task $index has a cached location at ${e.host} " + ", where there are executors " + set.mkString(",")) - } case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + ", but there are no executors alive there.") } - } case _ => } pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index @@ -437,7 +435,7 @@ private[spark] class TaskSetManager( } dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => { + case Some((index, taskLocality, speculative)) => // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() @@ -486,7 +484,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, taskName, index, serializedTask)) - } case _ => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3971e6c3826c..61ab3e87c571 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -121,11 +121,10 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( Some(Utils.deserialize[T](fileData)) } catch { case e: NoNodeException => None - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(zkPath) None - } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 374c79a7e5ac..1b7ac172defb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -55,11 +55,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(vol.setContainerPath(container_path) .setHostPath(host_path) .setMode(Volume.Mode.RO)) - case spec => { + case spec => logWarning(s"Unable to parse volume specs: $volumes. " + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") None - } } } .map { _.build() } @@ -90,11 +89,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(portmap.setHostPort(host_port.toInt) .setContainerPort(container_port.toInt) .setProtocol(protocol)) - case spec => { + case spec => logWarning(s"Unable to parse port mapping specs: $portmaps. " + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") None - } } } .map { _.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 233bdc23e647..1e322ac67941 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -124,11 +124,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { markErr() } } catch { - case e: Exception => { + case e: Exception => logError("driver.run() failed", e) error = Some(e) markErr() - } } } }.start() @@ -184,7 +183,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { - case r => { + case r => if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && @@ -196,7 +195,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } else { r } - } } // Filter any resource that has depleted. 
@@ -228,7 +226,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.asScala.map(attr => { + offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -236,7 +234,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { case Value.Type.TEXT => attr.getText } (attr.getName, attrValue) - }).toMap + }.toMap } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3d090a4353c3..918ae376f628 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -357,7 +357,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ * serialization. */ trait KryoRegistrator { - def registerClasses(kryo: Kryo) + def registerClasses(kryo: Kryo): Unit } private[serializer] object KryoSerializer { diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 5ead40e89e29..cb95246d5b0c 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -188,10 +188,9 @@ abstract class DeserializationStream { try { (readKey[Any](), readValue[Any]()) } catch { - case eof: EOFException => { + case eof: EOFException => finished = true null - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 637b2dfc193b..876cdfaa8760 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -69,10 +69,10 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Update the context task metrics for each record read. val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map(record => { + recordIter.map { record => readMetrics.incRecordsRead(1) record - }), + }, context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6cd7d6951851..be1e84a2ba93 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ private[spark] trait ShuffleWriterGroup { val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. 
*/ - def releaseWriters(success: Boolean) + def releaseWriters(success: Boolean): Unit } /** diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 9c92a501503c..f8d6e9fbbb90 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -147,7 +147,7 @@ private[v1] object AllStagesResource { speculative = uiData.taskInfo.speculative, accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, errorMessage = uiData.errorMessage, - taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics } + taskMetrics = uiData.metrics.map { convertUiTaskMetrics } ) } @@ -155,7 +155,7 @@ private[v1] object AllStagesResource { allTaskData: Iterable[TaskUIData], quantiles: Array[Double]): TaskMetricDistributions = { - val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq + val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 25edb9f1e4c2..4ec5b4bbb07c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -143,13 +143,12 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => { + case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() - } case _ => } } @@ -313,7 +312,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => { + case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) @@ -323,7 +322,6 @@ final class ShuffleBlockFetcherIterator( reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } - } case _ => } // Send fetch requests up to maxBytesInFlight diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c3c59f857dc4..119165f724f5 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -30,6 +30,7 @@ import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.server.nio.SelectChannelConnector import org.eclipse.jetty.server.ssl.SslSelectChannelConnector import org.eclipse.jetty.servlet._ +import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} @@ -350,4 +351,15 @@ private[spark] object JettyUtils extends Logging { private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) + rootHandler: ContextHandlerCollection) { + + def stop(): 
Unit = { + server.stop() + // Stop the ThreadPool if it supports stop() method (through LifeCycle). + // It is needed because stopping the Server won't stop the ThreadPool it uses. + val threadPool = server.getThreadPool + if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { + threadPool.asInstanceOf[LifeCycle].stop + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 250b7f2e5f48..2b0bc32cf655 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -129,7 +129,7 @@ private[spark] abstract class WebUI( } /** Initialize all components of the server. */ - def initialize() + def initialize(): Unit /** Bind to the HTTP server behind this web interface. */ def bind() { @@ -153,7 +153,7 @@ private[spark] abstract class WebUI( def stop() { assert(serverInfo.isDefined, "Attempted to stop %s before binding to a server!".format(className)) - serverInfo.get.server.stop() + serverInfo.get.stop() } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index cc476d61b569..a0ef80d9bdae 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -38,7 +38,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val content = maybeThreadDump.map { threadDump => val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => { + case (threadTrace1, threadTrace2) => val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { @@ -46,7 +46,6 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } else { v1 > v2 } - } }.map { thread => val threadId = thread.threadId { + stageData.foreach { data => hasInput = data.hasInput hasOutput = data.hasOutput hasShuffleRead = data.hasShuffleRead hasShuffleWrite = data.hasShuffleWrite hasBytesSpilled = data.hasBytesSpilled - }) + } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 645e2d2e360b..bd4797ae8e0c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -203,7 +203,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { // This could be empty if the JobProgressListener hasn't received information about the // stage or if the stage information has been garbage collected listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown", Seq.empty)) } val activeStages = Buffer[StageInfo]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index ed3ab66e3b68..13f5f84d06fe 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -396,13 +396,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { None } taskMetrics.foreach { m => - val oldMetrics = 
stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics) updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) taskData.taskInfo = info - taskData.taskMetrics = taskMetrics + taskData.metrics = taskMetrics taskData.errorMessage = errorMessage for ( @@ -506,9 +506,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates) taskData.foreach { t => if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.taskMetrics) + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) // Overwrite task metrics - t.taskMetrics = Some(metrics) + t.metrics = Some(metrics) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 689ab7dd5ed6..8a44bbd9fcd5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -330,7 +330,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else taskTable.dataSource.slicedTaskIds // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) + val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = if (validTasks.size == 0) { @@ -348,8 +348,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getDistributionQuantiles(data).map(d => ) } - val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorDeserializeTime.toDouble + val deserializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorDeserializeTime.toDouble } val deserializationQuantiles = +: getFormattedTimeQuantiles(deserializationTimes) - val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorRunTime.toDouble + val serviceTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorRunTime.toDouble } val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) - val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.jvmGCTime.toDouble + val gcTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.jvmGCTime.toDouble } val gcQuantiles = +: getFormattedTimeQuantiles(gcTimes) - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + val serializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.resultSerializationTime.toDouble } val serializationQuantiles = +: getFormattedTimeQuantiles(serializationTimes) - val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info, currentTime).toDouble + val gettingResultTimes = validTasks.map { taskUIData: TaskUIData => + getGettingResultTime(taskUIData.taskInfo, currentTime).toDouble } val gettingResultQuantiles = +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.peakExecutionMemory.toDouble + val peakExecutionMemory = validTasks.map { 
taskUIData: TaskUIData => + taskUIData.metrics.get.peakExecutionMemory.toDouble } val peakExecutionMemoryQuantiles = { @@ -427,30 +427,30 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ) } - val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble + val inputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val inputRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val inputQuantiles = +: getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) - val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + val outputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + val outputRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val outputQuantiles = +: getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) - val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble } val shuffleReadBlockedQuantiles = +: getFormattedTimeQuantiles(shuffleReadBlockedTimes) - val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble } - val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val shuffleReadTotalQuantiles = +: getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadRemoteQuantiles = +: getFormattedSizeQuantiles(shuffleReadRemoteSizes) - val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => - 
metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val shuffleWriteQuantiles = +: getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) - val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.memoryBytesSpilled.toDouble + val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = +: getFormattedSizeQuantiles(memoryBytesSpilledSizes) - val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.diskBytesSpilled.toDouble + val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = +: getFormattedSizeQuantiles(diskBytesSpilledSizes) @@ -601,7 +601,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 - val metricsOpt = taskUIData.taskMetrics + val metricsOpt = taskUIData.metrics val shuffleReadTime = metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) @@ -868,7 +868,8 @@ private[ui] class TaskDataSource( def slicedTaskIds: Set[Long] = _slicedTaskIds private def taskRow(taskData: TaskUIData): TaskTableRowData = { - val TaskUIData(info, metrics, errorMessage) = taskData + val info = taskData.taskInfo + val metrics = taskData.metrics val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) @@ -1014,7 +1015,7 @@ private[ui] class TaskDataSource( shuffleRead, shuffleWrite, bytesSpilled, - errorMessage.getOrElse("")) + taskData.errorMessage.getOrElse("")) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 78165d7b743e..b454ef1b204b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -105,12 +105,12 @@ private[spark] object UIData { /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. */ - case class TaskUIData( + class TaskUIData( var taskInfo: TaskInfo, - var taskMetrics: Option[TaskMetrics] = None, + var metrics: Option[TaskMetrics] = None, var errorMessage: Option[String] = None) - case class ExecutorUIData( + class ExecutorUIData( val startTime: Long, var finishTime: Option[Long] = None, var finishReason: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/util/CausedBy.scala b/core/src/main/scala/org/apache/spark/util/CausedBy.scala new file mode 100644 index 000000000000..73df446d981c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CausedBy.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Extractor Object for pulling out the root cause of an error. + * If the error contains no cause, it will return the error itself. + * + * Usage: + * try { + * ... + * } catch { + * case CausedBy(ex: CommitDeniedException) => ... + * } + */ +private[spark] object CausedBy { + + def unapply(e: Throwable): Option[Throwable] = { + Option(e.getCause).flatMap(cause => unapply(cause)).orElse(Some(e)) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 2f6924f7deef..489688cb0880 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -19,7 +19,8 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.collection.mutable.{Map, Set} +import scala.collection.mutable.{Map, Set, Stack} +import scala.language.existentials import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.apache.xbean.asm5.Opcodes._ @@ -77,35 +78,19 @@ private[spark] object ClosureCleaner extends Logging { */ private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) + val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail + val cr = getClassReader(stack.pop()) val set = Set[Class[_]]() cr.accept(new InnerClosureFinder(set), 0) for (cls <- set -- seen) { seen += cls - stack = cls :: stack + stack.push(cls) } } (seen - obj.getClass).toList } - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - cls match { - case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\u0000') - case java.lang.Void.TYPE => - // This should not happen because `Foo(void x) {}` does not compile. - throw new IllegalStateException("Unexpected void parameter in constructor") - case _ => new java.lang.Byte(0: Byte) - } - } else { - null - } - } - /** * Clean the given closure in place. * @@ -233,16 +218,24 @@ private[spark] object ClosureCleaner extends Logging { // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse var parent: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). 
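For context on the `CausedBy` extractor introduced above: a minimal usage sketch (not part of the patch; the helper is `private[spark]`, so it is only usable from Spark's own packages, as the accompanying test suite later in this patch also shows):

```scala
import java.io.IOException
import org.apache.spark.util.CausedBy

// Sketch only: react to the root cause of a (possibly wrapped) exception.
def describeFailure(work: () => Unit): String =
  try {
    work()
    "succeeded"
  } catch {
    // unapply walks getCause recursively and binds the innermost Throwable
    case CausedBy(root: IOException) => s"root cause was I/O: ${root.getMessage}"
    case CausedBy(root)              => s"root cause: ${root.getClass.getName}"
  }
```

Because `unapply` always returns `Some`, the second case is a catch-all on the unwrapped root cause; this recursive unwrapping is what `CausedBySuite` below exercises.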
- logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}") - parent = outerPairs.head._2 // e.g. SparkContext - outerPairs = outerPairs.tail - } else if (outerPairs.size > 0) { - logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}") + if (outerPairs.size > 0) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it + // as it may carry a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. + logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { logDebug(" + there are no enclosing objects!") }
diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index 153025cef247..3ea9139e1102 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -47,13 +47,12 @@ private[spark] abstract class EventLoop[E](name: String) extends Logging { try { onReceive(event) } catch { - case NonFatal(e) => { + case NonFatal(e) => try { onError(e) } catch { case NonFatal(e) => logError("Unexpected error in " + name, e) } - } } } } catch {
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 09d955300a64..558767e36f7d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -578,7 +578,9 @@ private[spark] object JsonProtocol { // The "Stage Infos" field was added in Spark 1.2.0 val stageInfos = Utils.jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { - stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) + stageIds.map { id => + new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", Seq.empty) + } } SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } @@ -686,7 +688,7 @@ private[spark] object JsonProtocol { } val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details) + stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, Seq.empty) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason @@ -811,8 +813,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) val inputMetrics = metrics.registerInputMetrics(readMethod) - inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long]) -
inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) + inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } // Updated blocks diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 3f627a01453e..6861a75612dd 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -151,13 +151,12 @@ object SizeEstimator extends Logging { // TODO: We could use reflection on the VMOption returned ? getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { - case e: Exception => { + case e: Exception => // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) val guessInWords = if (guess) "yes" else "not" logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) return guess - } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index c304629bcdbe..78e164cff773 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1260,26 +1260,35 @@ private[spark] object Utils extends Logging { } /** - * Execute a block of code, call the failure callbacks before finally block if there is any - * exceptions happen. But if exceptions happen in the finally block, do not suppress the original - * exception. + * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur + * in either the catch or the finally block, they are appended to the list of suppressed + * exceptions in original exception which is then rethrown. * - * This is primarily an issue with `finally { out.close() }` blocks, where - * close needs to be called to clean up `out`, but if an exception happened - * in `out.write`, it's likely `out` may be corrupted and `out.close` will + * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks, + * where the abort/close needs to be called to clean up `out`, but if an exception happened + * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will * fail as well. This would then suppress the original/likely more meaningful * exception from the original `out.write` call. 
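The scaladoc above describes the standard JVM suppressed-exception idiom. A rough, self-contained sketch of that pattern, using plain `Throwable.addSuppressed` rather than the Spark helper itself:

```scala
// Sketch only: keep the original write failure primary and attach any cleanup
// failure as a suppressed exception instead of letting it mask the root cause.
def writeThenClose(out: java.io.OutputStream, bytes: Array[Byte]): Unit = {
  var original: Throwable = null
  try {
    out.write(bytes)
  } catch {
    case t: Throwable =>
      original = t
      throw t
  } finally {
    try {
      out.close()                       // cleanup that may itself fail if `out` is corrupted
    } catch {
      case t: Throwable if original != null =>
        original.addSuppressed(t)       // the write failure stays the one that propagates
      // with no earlier failure, a close() exception is allowed to propagate as usual
    }
  }
}
```

Keeping the first failure primary is exactly what the rewritten helper does for both the catch (abort) and finally (close) callbacks that follow.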
*/ - def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = { + def tryWithSafeFinallyAndFailureCallbacks[T](block: => T) + (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = { var originalThrowable: Throwable = null try { block } catch { - case t: Throwable => + case cause: Throwable => // Purposefully not using NonFatal, because even fatal exceptions // we don't want to have our finallyBlock suppress - originalThrowable = t - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t) + originalThrowable = cause + try { + logError("Aborting task", originalThrowable) + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable) + catchBlock + } catch { + case t: Throwable => + originalThrowable.addSuppressed(t) + logWarning(s"Suppressing exception in catch: " + t.getMessage, t) + } throw originalThrowable } finally { try { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index b34880d3a748..6e80db2f51f9 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -32,10 +32,10 @@ private[spark] trait RollingPolicy { def shouldRollover(bytesToBeWritten: Long): Boolean /** Notify that rollover has occurred */ - def rolledOver() + def rolledOver(): Unit /** Notify that bytes have been written */ - def bytesWritten(bytes: Long) + def bytesWritten(bytes: Long): Unit /** Get the desired name of the rollover file */ def generateRolledOverFileSuffix(): String diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala index 70f3dd62b9b1..41f28f6e511e 100644 --- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala @@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait Pseudorandom { /** Set random seed. */ - def setSeed(seed: Long) + def setSeed(seed: Long): Unit } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index ec192a8543ae..37879d11caec 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.Properties import java.util.concurrent.Semaphore import javax.annotation.concurrent.GuardedBy @@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) // Now we're on the executors. // Deserialize the task and assert that its accumulators are zero'ed out. 
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer) + val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer) val taskDeser = serInstance.deserialize[DummyTask]( taskBytes, Thread.currentThread.getContextClassLoader) // Assert that executors see only zeros @@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener { private[spark] class DummyTask( val internalAccums: Seq[Accumulator[_]], val externalAccums: Seq[Accumulator[_]]) - extends Task[Int](0, 0, 0, internalAccums) { + extends Task[Int](0, 0, 0, internalAccums, new Properties) { override def runTask(c: TaskContext): Int = 1 } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 67d722c1dc15..2110d3d770d5 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -320,7 +320,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 80a1de6065b4..ee6b99146190 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -928,8 +928,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { numTasks: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { - new StageInfo( - stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) + new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", + Seq.empty, taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 6ffa1c8ac140..cd7d2e15700d 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.Properties import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers @@ -335,16 +336,16 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem, - InternalAccumulator.create(sc))) + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. 
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem, - InternalAccumulator.create(sc))) + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -372,8 +373,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem, - InternalAccumulator.create(sc))) + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem, + InternalAccumulator.createAll(sc))) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 3706455c3fac..8feb3dee050d 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -82,20 +82,18 @@ package object testPackage extends Assertions { val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt - } case _ => fail("Did not match expected call site format") } curCallSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) - } case _ => fail("Did not match expected call site format") } } diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index f7a13ab3996d..09e21646ee74 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -35,7 +35,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. 
} diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 088b05403c1a..d91f50f18f43 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -285,8 +285,8 @@ class TaskMetricsSuite extends SparkFunSuite { // set and increment values in.setBytesRead(1L) in.setBytesRead(2L) - in.incRecordsReadInternal(1L) - in.incRecordsReadInternal(2L) + in.incRecordsRead(1L) + in.incRecordsRead(2L) in.setReadMethod(DataReadMethod.Disk) // assert new values exist assertValEquals(_.bytesRead, BYTES_READ, 2L) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 2b5e4b80e96a..362cd861cc24 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.memory +import java.util.Properties + import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} /** @@ -31,6 +33,7 @@ object MemoryTestingUtils { taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, + localProperties = new Properties, metricsSystem = env.metricsSystem) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 6da18cfd4972..ed15e77ff142 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -108,11 +108,11 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) - val exec0 = new NettyBlockTransferService(conf0, securityManager0, numCores = 1) + val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", numCores = 1) exec0.init(blockManager) val securityManager1 = new SecurityManager(conf1) - val exec1 = new NettyBlockTransferService(conf1, securityManager1, numCores = 1) + val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", numCores = 1) exec1.init(blockManager) val result = fetchBlock(exec0, exec1, "1", blockId) match { diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index cc1a9e028708..f3c156e4f709 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -80,7 +80,7 @@ class NettyBlockTransferServiceSuite .set("spark.blockManager.port", port.toString) val securityManager = new SecurityManager(conf) val blockDataManager = mock(classOf[BlockDataManager]) - val service = new NettyBlockTransferService(conf, securityManager, numCores = 1) + val service = new NettyBlockTransferService(conf, securityManager, "localhost", numCores = 1) service.init(blockDataManager) service } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 
43e61241b6cb..cebac2097f38 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -127,9 +127,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val reply = rpcEndpointRef.askWithRetry[String]("hello") @@ -141,9 +140,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) @@ -164,10 +162,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => Thread.sleep(100) context.reply(msg) - } } }) @@ -317,10 +314,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { - case m => { + case m => self callSelfSuccessfully = true - } } }) @@ -682,9 +678,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = localEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2293c11dad73..fd96fb04f8b2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1144,7 +1144,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // SPARK-9809 -- this stage is submitted without a task for each partition (because some of // the shuffle map output is still available from stage 0); make sure we've still got internal // accumulators setup - assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty) + assert(scheduler.stageIdToStage(2).latestInfo.internalAccumulators.nonEmpty) completeShuffleMapStageSuccessfully(2, 0, 2) completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) assert(results === Map(0 -> 1234, 1 -> 1235)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index f7e16af9d3a9..e3e6df6831de 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,12 +17,14 @@ package org.apache.spark.scheduler +import java.util.Properties + import org.apache.spark.TaskContext class FakeTask( stageId: Int, prefLocs: Seq[TaskLocation] = Nil) - extends Task[Int](stageId, 0, 0, Seq.empty) { + extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala 
index 1dca4bd89fd9..76a708764596 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.util.Properties import org.apache.spark.TaskContext @@ -25,7 +26,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index c4cf2f9f7075..86911d2211a3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.scheduler +import java.util.Properties + import org.mockito.Matchers.any import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.executor.TaskMetricsSuite +import org.apache.spark.executor.{Executor, TaskMetricsSuite} import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils @@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) intercept[RuntimeException] { task.run(0, 0, null) } @@ -79,7 +82,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) intercept[RuntimeException] { task.run(0, 0, null) } @@ -170,9 +174,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val initialAccums = InternalAccumulator.createAll() // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. 
- val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) { + val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + new Properties, SparkEnv.get.metricsSystem, initialAccums) context.taskMetrics.registerAccumulator(acc1) @@ -189,6 +194,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) } + test("localProperties are propagated to executors correctly") { + sc = new SparkContext("local", "test") + sc.setLocalProperty("testPropKey", "testPropValue") + val res = sc.parallelize(Array(1), 1).map(i => i).map(i => { + val inTask = TaskContext.get().getLocalProperty("testPropKey") + val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey") + s"$inTask,$inDeser" + }).collect() + assert(res === Array("testPropValue,testPropValue")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 167d3fd2e460..ade8e84d848f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.Map import scala.collection.mutable @@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 7ee76aa4c6f9..9d1bd7ec89bc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.Properties + import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.implicitConversions import scala.reflect.ClassTag @@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 2ec5319d5571..d26df7e760ce 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -62,7 +62,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo name: String = 
SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) conf.set("spark.memory.offHeap.size", maxMem.toString) - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val store = new BlockManager(name, rpcEnv, master, serializerManager, conf, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 66b28de10f97..a1c2933584ac 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -78,7 +78,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE conf.set("spark.memory.offHeap.size", maxMem.toString) val serializer = new KryoSerializer(conf) val transfer = transferService - .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1)) + .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1)) val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, @@ -490,7 +490,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { @@ -852,7 +852,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. 
conf.set("spark.testing.memory", "1200") - val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) + val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", numCores = 1) val memoryManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 9876bded33a0..7d4c0863bc96 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -322,11 +322,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 207) assert(stage0Data.outputBytes == 116) assert(stage1Data.outputBytes == 208) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) - assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 102) - assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 202) // task that was included in a heartbeat @@ -355,9 +355,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 614) assert(stage0Data.outputBytes == 416) assert(stage1Data.outputBytes == 616) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) - assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 402) } } diff --git a/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala new file mode 100644 index 000000000000..4a80e3f1f452 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class CausedBySuite extends SparkFunSuite { + + test("For an error without a cause, should return the error") { + val error = new Exception + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === error) + } + + test("For an error with a cause, should return the cause of the error") { + val cause = new Exception + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === cause) + } + + test("For an error with a cause that itself has a cause, return the root cause") { + val causeOfCause = new Exception + val cause = new Exception(causeOfCause) + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === causeOfCause) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6a2d4c9f2cec..de6f408fa82b 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -853,7 +853,7 @@ private[spark] object JsonProtocolSuite extends Assertions { if (hasHadoopInput) { val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) inputMetrics.setBytesRead(d + e + f) - inputMetrics.incRecordsReadInternal(if (hasRecords) (d + e + f) / 100 else -1) + inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) } else { val sr = t.registerTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 2794b3d235ce..023fba536915 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -20,8 +20,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -123,7 +123,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -136,10 +136,10 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -157,7 +157,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 4906fe9cfae4..003c540d72a0 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -22,8 +22,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -114,7 +114,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -126,11 
+126,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -148,7 +148,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 23ff5cfa2ea4..80fbaea22238 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -22,8 +22,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -115,7 +115,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -127,11 +127,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -149,7 +149,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9b5a5643f392..b2c2a4caec86 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -26,8 +26,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -121,7 +121,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -133,11 +133,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -155,7 +155,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1dca2fc55ad3..71e51883d5ab 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -26,8 +26,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -122,7 +122,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar 
-kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -134,11 +134,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -156,7 +156,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index bb04ec6ee67d..c844bcff7e4f 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -256,9 +256,21 @@ def __hash__(self): ) +mllib_local = Module( + name="mllib-local", + dependencies=[], + source_file_regexes=[ + "mllib-local", + ], + sbt_test_goals=[ + "mllib-local/test", + ] +) + + mllib = Module( name="mllib", - dependencies=[streaming, sql], + dependencies=[mllib_local, streaming, sql], source_file_regexes=[ "data/mllib/", "mllib/", diff --git a/docs/configuration.md b/docs/configuration.md index 937852ffdecd..16d5be62f9e8 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,11 +225,14 @@ Apart from these, the following properties are also available, and may be useful + your default properties file. @@ -269,9 +272,9 @@ Apart from these, the following properties are also available, and may be useful diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 45155c8ad1f1..eaf4f6d84336 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe +## Naive Bayes + +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple +probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between the features. The spark.ml implementation currently supports both [multinomial +naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +**Example** + +
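Independently of the bundled example programs referenced below, a rough sketch of the spark.ml API, assuming DataFrames `training` and `test` that already carry the conventional "label" and "features" columns:

```scala
import org.apache.spark.ml.classification.NaiveBayes

// Sketch only: fit a multinomial naive Bayes model and score a held-out set.
val nb = new NaiveBayes()
  .setModelType("multinomial")   // "bernoulli" is the other supported model type
  .setSmoothing(1.0)             // additive (Laplace) smoothing
val model = nb.fit(training)
model.transform(test).select("label", "prediction").show()
```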
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.
+
+{% include_example python/ml/naive_bayes_example.py %}
+</div>
+</div>
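As a quick orientation for the Scala tab above, here is a condensed sketch of the spark.ml Naive Bayes flow. It mirrors the NaiveBayesExample.scala and naive_bayes_example.py files added later in this patch; the object name, the 60/40 split, and the explicit smoothing/modelType settings follow the Python example and are illustrative, not a verbatim copy of the shipped code.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SQLContext

object NaiveBayesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("NaiveBayesSketch"))
    val sqlContext = new SQLContext(sc)

    // Load LIBSVM data as a DataFrame with "label" and "features" columns.
    val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val Array(train, test) = data.randomSplit(Array(0.6, 0.4), seed = 1234L)

    // Train a multinomial Naive Bayes model and score the held-out split.
    val model = new NaiveBayes()
      .setSmoothing(1.0)
      .setModelType("multinomial")
      .fit(train)
    val predictions = model.transform(test)

    // Evaluate with the same multiclass "precision" metric used by the examples in this patch.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    println("Precision: " + evaluator.evaluate(predictions))

    sc.stop()
  }
}
```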
+ # Regression diff --git a/docs/ml-features.md b/docs/ml-features.md index 4fe8eefc260d..70812eb5e229 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -149,6 +149,15 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %} + +
+<div data-lang="python" markdown="1">
+
+Refer to the [CountVectorizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizer)
+and the [CountVectorizerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizerModel)
+for more details on the API.
+
+{% include_example python/ml/count_vectorizer_example.py %}
+</div>
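For comparison with the Python example referenced above, a minimal Scala sketch of the same CountVectorizer usage; the toy two-row corpus and parameter values are taken from the count_vectorizer_example.py added later in this patch, and the object name is made up for illustration.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.SQLContext

object CountVectorizerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("CountVectorizerSketch"))
    val sqlContext = new SQLContext(sc)

    // Each row is a bag of words with an id, as in the Python example.
    val df = sqlContext.createDataFrame(Seq(
      (0, Array("a", "b", "c")),
      (1, Array("a", "b", "b", "c", "a"))
    )).toDF("id", "words")

    // Fit a CountVectorizerModel over the corpus and produce term-frequency vectors.
    val cv = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(3)
      .setMinDF(2.0)
    cv.fit(df).transform(df).show()

    sc.stop()
  }
}
```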
# Feature Transformers @@ -413,6 +422,14 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %} + +
+<div data-lang="python" markdown="1">
+
+Refer to the [DCT Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.DCT)
+for more details on the API.
+
+{% include_example python/ml/dct_example.py %}
+</div>
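The DCT transformer works the same way from Scala; a minimal sketch using the three input vectors from the dct_example.py added later in this patch (the object name is illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.DCT
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

object DCTSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("DCTSketch"))
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.createDataFrame(Seq(
      Vectors.dense(0.0, 1.0, -2.0, 3.0),
      Vectors.dense(-1.0, 2.0, 4.0, -7.0),
      Vectors.dense(14.0, -2.0, -5.0, 1.0)
    ).map(Tuple1.apply)).toDF("features")

    // Apply a forward 1-D discrete cosine transform to each feature vector.
    val dct = new DCT()
      .setInverse(false)
      .setInputCol("features")
      .setOutputCol("featuresDCT")
    dct.transform(df).select("featuresDCT").show(false)

    sc.stop()
  }
}
```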
## StringIndexer @@ -771,6 +788,14 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %} + +
+<div data-lang="python" markdown="1">
+
+Refer to the [MinMaxScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScaler)
+for more details on the API.
+
+{% include_example python/ml/min_max_scaler_example.py %}
+</div>
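In Scala the same two-step pattern applies: fit() computes per-feature summary statistics and transform() rescales each feature to [0, 1] (the default range). A minimal sketch, assuming the bundled sample_libsvm_data.txt used by the Python example (object name is illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.sql.SQLContext

object MinMaxScalerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MinMaxScalerSketch"))
    val sqlContext = new SQLContext(sc)

    val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    // fit() computes the per-feature min/max; transform() rescales to [0, 1].
    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
    val scalerModel = scaler.fit(dataFrame)
    scalerModel.transform(dataFrame).show()

    sc.stop()
  }
}
```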
@@ -803,6 +828,14 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java %} + +
+<div data-lang="python" markdown="1">
+
+Refer to the [MaxAbsScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScaler)
+for more details on the API.
+
+{% include_example python/ml/max_abs_scaler_example.py %}
+</div>
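MaxAbsScaler follows the same fit/transform pattern but divides each feature by its maximum absolute value, so values land in [-1, 1] without shifting and sparsity is preserved. A minimal Scala sketch mirroring the max_abs_scaler_example.py added later in this patch (object name is illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.MaxAbsScaler
import org.apache.spark.sql.SQLContext

object MaxAbsScalerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MaxAbsScalerSketch"))
    val sqlContext = new SQLContext(sc)

    val dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    // Each feature is divided by its max absolute value; no shifting, so sparsity is kept.
    val scaler = new MaxAbsScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
    scaler.fit(dataFrame).transform(dataFrame).show()

    sc.stop()
  }
}
```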
## Bucketizer diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ddc75a70b9d5..09701abdb057 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -342,7 +342,9 @@ If you need a reference to the proper location to put log files in the YARN so t diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 63310be22cbd..2d9849d0328e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1502,7 +1502,7 @@ val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. @@ -1540,7 +1540,7 @@ DataFrame people = sqlContext.read().json("examples/src/main/resources/people.js // The inferred schema can be visualized using the printSchema() method. people.printSchema(); // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. @@ -1578,7 +1578,7 @@ people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. @@ -1617,7 +1617,7 @@ people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java deleted file mode 100644 index 07edeb3e521c..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import java.util.List; - -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using CrossValidator. - * This example also demonstrates how Pipelines are Estimators. - * - * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and - * {@link org.apache.spark.examples.ml.Document} defined in the Scala example - * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}. - * - * Run with - *
- * bin/run-example ml.JavaCrossValidatorExample
- * 
- */ -public class JavaCrossValidatorExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // Prepare training documents, which are labeled. - List localTraining = Lists.newArrayList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); - Dataset training = jsql.createDataFrame( - jsc.parallelize(localTraining), LabeledDocument.class); - - // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. - Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); - HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); - Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - - // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. - // This will allow us to jointly choose parameters for all Pipeline stages. - // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, - // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - crossval.setEstimatorParamMaps(paramGrid); - crossval.setNumFolds(2); // Use 3+ in practice - - // Run cross-validation, and choose the best set of parameters. - CrossValidatorModel cvModel = crossval.fit(training); - - // Prepare test documents, which are unlabeled. - List localTest = Lists.newArrayList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); - Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); - - // Make predictions on test documents. cvModel uses the best model found (lrModel). 
- Dataset predictions = cvModel.transform(test); - for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); - } - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index fbd881766983..0ba94786d4e5 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -146,7 +146,7 @@ MyJavaLogisticRegression setMaxIter(int value) { // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(Dataset dataset) { + public MyJavaLogisticRegressionModel train(Dataset dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java index 6ac4aea3c483..4994f8f9fa85 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -32,7 +32,15 @@ import org.apache.spark.sql.SQLContext; /** - * Java example for Model Selection via Train Validation Split. + * Java example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample + * }}} */ public class JavaModelSelectionViaTrainValidationSplitExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java new file mode 100644 index 000000000000..41d7ad75b9d4 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.NaiveBayes; +import org.apache.spark.ml.classification.NaiveBayesModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +/** + * An example for Naive Bayes Classification. + */ +public class JavaNaiveBayesExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + // Split the data into train and test + Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + Dataset train = splits[0]; + Dataset test = splits[1]; + + // create the trainer and set its parameters + NaiveBayes nb = new NaiveBayes(); + // train the model + NaiveBayesModel model = nb.fit(train); + // compute precision on the test set + Dataset result = model.transform(test); + Dataset predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java deleted file mode 100644 index 09bbc39c01fe..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using TrainValidationSplit. - * - * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} - * using linear regression. 
- * - * Run with - * {{{ - * bin/run-example ml.JavaTrainValidationSplitExample - * }}} - */ -public class JavaTrainValidationSplitExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - Dataset data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - // Prepare training and test data. - Dataset[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); - Dataset training = splits[0]; - Dataset test = splits[1]; - - LinearRegression lr = new LinearRegression(); - - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // TrainValidationSplit will try all combinations of values and determine best model using - // the evaluator. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - - // In this case the estimator is simply the linear regression. - // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid); - - // 80% of the data will be used for training and the remaining 20% for validation. - trainValidationSplit.setTrainRatio(0.8); - - // Run train validation split, and choose the best set of parameters. - TrainValidationSplitModel model = trainValidationSplit.fit(training); - - // Make predictions on test data. model is the model with combination of parameters - // that performed best. - model.transform(test) - .select("features", "label", "prediction") - .show(); - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java deleted file mode 100644 index 36baf5868736..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import java.util.ArrayList; - -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.fpm.FPGrowth; -import org.apache.spark.mllib.fpm.FPGrowthModel; - -/** - * Java example for mining frequent itemsets using FP-growth. - * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt - */ -public class JavaFPGrowthExample { - - public static void main(String[] args) { - String inputFile; - double minSupport = 0.3; - int numPartition = -1; - if (args.length < 1) { - System.err.println( - "Usage: JavaFPGrowth [minSupport] [numPartition]"); - System.exit(1); - } - inputFile = args[0]; - if (args.length >= 2) { - minSupport = Double.parseDouble(args[1]); - } - if (args.length >= 3) { - numPartition = Integer.parseInt(args[2]); - } - - SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD> transactions = sc.textFile(inputFile).map( - new Function>() { - @Override - public ArrayList call(String s) { - return Lists.newArrayList(s.split(" ")); - } - } - ); - - FPGrowthModel model = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartition) - .run(transactions); - - for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java deleted file mode 100644 index e575eedeb465..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.clustering.KMeans; -import org.apache.spark.mllib.clustering.KMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -/** - * Example using MLlib KMeans from Java. 
- */ -public final class JavaKMeans { - - private static class ParsePoint implements Function { - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public Vector call(String line) { - String[] tok = SPACE.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - return Vectors.dense(point); - } - } - - public static void main(String[] args) { - if (args.length < 3) { - System.err.println( - "Usage: JavaKMeans []"); - System.exit(1); - } - String inputFile = args[0]; - int k = Integer.parseInt(args[1]); - int iterations = Integer.parseInt(args[2]); - int runs = 1; - - if (args.length >= 4) { - runs = Integer.parseInt(args[3]); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(inputFile); - - JavaRDD points = lines.map(new ParsePoint()); - - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL()); - - System.out.println("Cluster centers:"); - for (Vector center : model.clusterCenters()) { - System.out.println(" " + center); - } - double cost = model.computeCost(points.rdd()); - System.out.println("Cost: " + cost); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java index 006d96d11196..2d89c768fcfc 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java @@ -58,6 +58,13 @@ public Vector call(String s) { int numIterations = 20; KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations); + System.out.println("Cluster centers:"); + for (Vector center: clusters.clusterCenters()) { + System.out.println(" " + center); + } + double cost = clusters.computeCost(parsedData.rdd()); + System.out.println("Cost: " + cost); + // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java deleted file mode 100644 index de8e739ac925..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.DistributedLDAModel; -import org.apache.spark.mllib.clustering.LDA; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class JavaLDAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("LDA Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_lda_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) { - values[i] = Double.parseDouble(sarray[i]); - } - return Vectors.dense(values); - } - } - ); - // Index documents with unique IDs - JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 doc_id) { - return doc_id.swap(); - } - } - )); - corpus.cache(); - - // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); - - // Output topics. Each is a distribution over words (matching word count vectors) - System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() - + " words):"); - Matrix topics = ldaModel.topicsMatrix(); - for (int topic = 0; topic < 3; topic++) { - System.out.print("Topic " + topic + ":"); - for (int word = 0; word < ldaModel.vocabSize(); word++) { - System.out.print(" " + topics.apply(word, topic)); - } - System.out.println(); - } - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java deleted file mode 100644 index eceb6927d555..000000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; - -/** - * Logistic regression based classification using ML Lib. - */ -public final class JavaLR { - - static class ParsePoint implements Function { - private static final Pattern COMMA = Pattern.compile(","); - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public LabeledPoint call(String line) { - String[] parts = COMMA.split(line); - double y = Double.parseDouble(parts[0]); - String[] tok = SPACE.split(parts[1]); - double[] x = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - x[i] = Double.parseDouble(tok[i]); - } - return new LabeledPoint(y, Vectors.dense(x)); - } - } - - public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaLR "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(args[0]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[1]); - int iterations = Integer.parseInt(args[2]); - - // Another way to configure LogisticRegression - // - // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); - // lr.optimizer().setNumIterations(iterations) - // .setStepSize(stepSize) - // .setMiniBatchFraction(1.0); - // lr.setIntercept(true); - // LogisticRegressionModel model = lr.train(points.rdd()); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); - - System.out.print("Final w: " + model.weights()); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java index 86c389e11cfd..72bbb2a8fa46 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java @@ -35,6 +35,7 @@ public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("JavaStratifiedSamplingExample"); JavaSparkContext jsc = new JavaSparkContext(conf); + @SuppressWarnings("unchecked") // $example on$ List> list = new ArrayList<>( Arrays.>asList( diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py new file mode 100644 index 000000000000..e839f645f70b --- /dev/null +++ b/examples/src/main/python/ml/count_vectorizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import CountVectorizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="CountVectorizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Input data: Each row is a bag of words with a ID. + df = sqlContext.createDataFrame([ + (0, "a b c".split(" ")), + (1, "a b b c a".split(" ")) + ], ["id", "words"]) + + # fit a CountVectorizerModel from the corpus. + cv = CountVectorizer(inputCol="words", outputCol="features", vocabSize=3, minDF=2.0) + model = cv.fit(df) + result = model.transform(df) + result.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py new file mode 100644 index 000000000000..264d47f404cb --- /dev/null +++ b/examples/src/main/python/ml/dct_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import DCT +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="DCTExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (Vectors.dense([0.0, 1.0, -2.0, 3.0]),), + (Vectors.dense([-1.0, 2.0, 4.0, -7.0]),), + (Vectors.dense([14.0, -2.0, -5.0, 1.0]),)], ["features"]) + + dct = DCT(inverse=False, inputCol="features", outputCol="featuresDCT") + + dctDf = dct.transform(df) + + for dcts in dctDf.select("featuresDCT").take(3): + print(dcts) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py new file mode 100644 index 000000000000..d9b69eef1cd8 --- /dev/null +++ b/examples/src/main/python/ml/max_abs_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MaxAbsScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MaxAbsScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MaxAbsScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MaxAbsScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [-1, 1]. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py new file mode 100644 index 000000000000..2f8e4ade468b --- /dev/null +++ b/examples/src/main/python/ml/min_max_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MinMaxScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MinMaxScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MinMaxScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [min, max]. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py new file mode 100644 index 000000000000..db8fbea9bf9b --- /dev/null +++ b/examples/src/main/python/ml/naive_bayes_example.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import NaiveBayes +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="naive_bayes_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + + # create the trainer and set its parameters + nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + + # train the model + model = nb.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 973b005f91f6..ca4eea235683 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -106,9 +106,8 @@ object CassandraCQLTest { println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { - case (key, value) => { + case (key, value) => (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) - } } val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) aggregatedRDD.collect().foreach { @@ -116,11 +115,10 @@ object CassandraCQLTest { } val casoutputCF = aggregatedRDD.map { - case (productId, saleCount) => { + case (productId, saleCount) => val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) (outKey, outVal) - } } casoutputCF.saveAsNewAPIHadoopFile( diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 6a8f73ad000f..eff840d36e8d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -90,9 +90,8 @@ object CassandraTest { // Let us first get all the paragraphs from the retrieved rows val paraRdd = casRdd.map { - case (key, value) => { + case (key, value) => ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } } // Lets get the word count in paras @@ -103,7 +102,7 @@ object CassandraTest { } counts.map { - 
case (word, count) => { + case (word, count) => val colWord = new org.apache.cassandra.thrift.Column() colWord.setName(ByteBufferUtil.bytes("word")) colWord.setValue(ByteBufferUtil.bytes(word)) @@ -122,7 +121,6 @@ object CassandraTest { mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) - } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index af5f216f28ba..fa1010195551 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -104,16 +104,14 @@ object LocalALS { def main(args: Array[String]) { args match { - case Array(m, u, f, iters) => { + case Array(m, u, f, iters) => M = m.toInt U = u.toInt F = f.toInt ITERATIONS = iters.toInt - } - case _ => { + case _ => System.err.println("Usage: LocalALS ") System.exit(1) - } } showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index c1f63c6a1dce..8d127f9b3542 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String) def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = { // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index af90652b55a1..7af011571f76 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -23,8 +23,8 @@ import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.KMeans import org.apache.spark.mllib.linalg.Vectors -// $example off$ import org.apache.spark.sql.{DataFrame, SQLContext} +// $example off$ /** * An example demonstrating a k-means clustering. 
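Several of the Scala example edits above and below (CassandraCQLTest, CassandraTest, LocalALS, OneVsRestExample, DecisionTreeRunner) apply the same cleanup: the arrow of a `case` clause already opens a block, so wrapping the body in an extra pair of braces is redundant. A tiny self-contained illustration of the before/after style; the values here are made up for this sketch and are not taken from the examples:

```scala
object CaseBraceSketch {
  def main(args: Array[String]): Unit = {
    val pairs = Seq(("spark", 2), ("hadoop", 1))

    // Before: an extra brace pair around the case body (legal, but noisy).
    val upperBefore = pairs.map {
      case (word, count) => {
        val label = word.toUpperCase
        (label, count)
      }
    }

    // After: the case arrow already introduces a block, so the braces can be dropped.
    val upperAfter = pairs.map {
      case (word, count) =>
        val label = word.toUpperCase
        (label, count)
    }

    // Both forms compile to the same function; only the surface syntax differs.
    assert(upperBefore == upperAfter)
    println(upperAfter)
  }
}
```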
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala new file mode 100644 index 000000000000..5ea1270c9781 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.{NaiveBayes} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ +import org.apache.spark.sql.SQLContext + +object NaiveBayesExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NaiveBayesExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a NaiveBayes model. + val model = new NaiveBayes() + .fit(trainingData) + + // Select example rows to display. + val predictions = model.transform(testData) + predictions.show() + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("precision") + val precision = evaluator.evaluate(predictions) + println("Precision:" + precision) + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index a0bb5dabf457..0b5d31c0ff90 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -118,17 +118,15 @@ object OneVsRestExample { val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { - case Some(t) => { + case Some(t) => // compute the number of features in the training set. 
val numFeatures = inputData.first().getAs[Vector](1).size val testData = sqlContext.read.option("numFeatures", numFeatures.toString) .format("libsvm").load(t) Array[DataFrame](inputData, testData) - } - case None => { + case None => val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) - } } val Array(train, test) = data.map(_.cache()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index c263f4f595a3..ee811d3aa101 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -180,7 +180,7 @@ object DecisionTreeRunner { } // For classification, re-index classes if needed. val (examples, classIndexMap, numClasses) = algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() val sortedClasses = classCounts.keys.toList.sorted @@ -209,7 +209,6 @@ object DecisionTreeRunner { println(s"$c\t$frac\t${classCounts(c)}") } (examples, classIndexMap, numClasses) - } case Regression => (origExamples, null, 0) case _ => @@ -225,7 +224,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val testExamples = { if (classIndexMap.isEmpty) { @@ -235,7 +234,6 @@ object DecisionTreeRunner { } } Array(examples, testExamples) - } case Regression => Array(examples, origTestExamples) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index b6b8bc33f7e1..bb2af9cd72e2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -116,7 +116,7 @@ object RecoverableNetworkWordCount { val lines = ssc.socketTextStream(ip, port) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => // Get or register the blacklist Broadcast val blacklist = WordBlacklist.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator @@ -135,7 +135,7 @@ object RecoverableNetworkWordCount { println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) Files.append(output + "\n", outputFile, Charset.defaultCharset()) - }) + } ssc } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 3727f8fe6a21..918e124065e4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -59,7 +59,7 @@ object SqlNetworkWordCount { val words = lines.flatMap(_.split(" ")) // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD((rdd: RDD[String], time: Time) => { + 
words.foreachRDD { (rdd: RDD[String], time: Time) => // Get the singleton instance of SQLContext val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) import sqlContext.implicits._ @@ -75,7 +75,7 @@ object SqlNetworkWordCount { sqlContext.sql("select word, count(*) as total from words group by word") println(s"========= $time =========") wordCountsDataFrame.show() - }) + } ssc.start() ssc.awaitTermination() diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 1764aa9465c4..17fd7d781c9a 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -34,6 +34,13 @@ docker-integration-tests + + + db2 + https://app.camunda.com/nexus/content/repositories/public/ + + + com.spotify @@ -180,5 +187,28 @@ + + + + com.ibm.db2.jcc + db2jcc4 + 10.5.0.5 + jar + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala new file mode 100644 index 000000000000..4fe1ef669720 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.scalatest._ + +import org.apache.spark.tags.DockerTest + +@DockerTest +@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker +class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "lresende/db2express-c:10.5.0.5-3.10.0" + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept" + ) + override val usesIpc = true + override val jdbcPort: Int = 50000 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/foo:user=db2inst1;password=rootpass;" + override def getStartupProcessName: Option[String] = Some("db2start") + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y VARCHAR(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. 
+ conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, " + + "e CHAR FOR BIT DATA)").executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps'") + .executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + 
assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index f73231fc80a0..c36f4d5f9548 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -44,6 +44,11 @@ abstract class DatabaseOnDocker { */ val env: Map[String, String] + /** + * Whether or not to use IPC mode for shared memory when starting the Docker image. + */ + val usesIpc: Boolean + /** * The container-internal JDBC port that the database listens on. */ @@ -53,6 +58,11 @@ abstract class DatabaseOnDocker { * Return a JDBC URL that connects to the database running at the given IP address and port. */ def getJdbcUrl(ip: String, port: Int): String + + /** + * Optional process to run when the container starts. + */ + def getStartupProcessName: Option[String] } abstract class DockerJDBCIntegrationSuite @@ -97,17 +107,23 @@ abstract class DockerJDBCIntegrationSuite val dockerIp = DockerUtils.getDockerIp() val hostConfig: HostConfig = HostConfig.builder() .networkMode("bridge") + .ipcMode(if (db.usesIpc) "host" else "") .portBindings( Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) .build() // Create the database container: - val config = ContainerConfig.builder() + val containerConfigBuilder = ContainerConfig.builder() .image(db.imageName) .networkDisabled(false) .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) .hostConfig(hostConfig) .exposedPorts(s"${db.jdbcPort}/tcp") - .build() + if (db.getStartupProcessName.isDefined) { + containerConfigBuilder + .cmd(db.getStartupProcessName.get) + } + val config = containerConfigBuilder.build() + // Create the database container: containerId = docker.createContainer(config).id // Start the container and wait until the database can accept JDBC connections: docker.startContainer(containerId) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index c68e4dc4933b..a70ed98b52d5 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -30,9 +30,11 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) + override val usesIpc = false override val jdbcPort: Int = 3306
override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 8a0f938f7e3b..2fc174eb1b3a 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -52,9 +52,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) + override val usesIpc = false override val jdbcPort: Int = 1521 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index d55cdcf28b23..79dd70116ecb 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -32,9 +32,11 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) + override val usesIpc = false override val jdbcPort = 5432 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 719fca0938b3..8050ec357e26 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -129,9 +129,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha * @param success Whether the batch was successful or not. 
*/ private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach(processor => { + removeAndGetProcessor(sequenceNumber).foreach { processor => processor.batchProcessed(success) - }) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 14dffb15fef9..41f27e937662 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -88,23 +88,23 @@ class SparkSink extends AbstractSink with Logging with Configurable { // dependencies which are being excluded in the build. In practice, // Netty dependencies are already available on the JVM as Flume would have pulled them in. serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach(server => { + serverOpt.foreach { server => logInfo("Starting Avro server for sink: " + getName) server.start() - }) + } super.start() } override def stop() { logInfo("Stopping Spark Sink: " + getName) - handler.foreach(callbackHandler => { + handler.foreach { callbackHandler => callbackHandler.shutdown() - }) - serverOpt.foreach(server => { + } + serverOpt.foreach { server => logInfo("Stopping Avro Server for sink: " + getName) server.close() server.join() - }) + } blockingLatch.countDown() super.stop() } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index b15c2097e550..19e736f01697 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -110,7 +110,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg("Something went wrong. Channel was " + "unable to create a transaction!") } - txOpt.foreach(tx => { + txOpt.foreach { tx => tx.begin() val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) val loop = new Breaks @@ -145,7 +145,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // At this point, the events are available, so fill them into the event batch eventBatch = new EventBatch("", seqNum, events) } - }) + } } catch { case interrupted: InterruptedException => // Don't pollute logs if the InterruptedException came from this being stopped @@ -156,9 +156,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, logWarning("Error while processing transaction.", e) eventBatch.setErrorMsg(e.getMessage) try { - txOpt.foreach(tx => { + txOpt.foreach { tx => rollbackAndClose(tx, close = true) - }) + } } finally { txOpt = None } @@ -174,7 +174,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, */ private def processAckOrNack() { batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach(tx => { + txOpt.foreach { tx => if (batchSuccess) { try { logDebug("Committing transaction") @@ -197,7 +197,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // cause issues. 
This is required to ensure the TransactionProcessor instance is not leaked parent.removeAndGetProcessor(seqNum) } - }) + } } /** diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 250bfc1718db..54565840fa66 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -79,11 +79,11 @@ private[streaming] class FlumePollingReceiver( override def onStart(): Unit = { // Create the connections to each Flume agent. - addresses.foreach(host => { + addresses.foreach { host => val transceiver = new NettyTransceiver(host, channelFactory) val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) connections.add(new FlumeConnection(transceiver, client)) - }) + } for (i <- 0 until parallelism) { logInfo("Starting Flume Polling Receiver worker threads..") // Threads that pull data from Flume. diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 1a96df6e94b9..6a4dafb8eddb 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -123,9 +123,9 @@ private[flume] class PollingFlumeTestUtils { val latch = new CountDownLatch(batchCount * channels.size) sinks.foreach(_.countdownWhenBatchReceived(latch)) - channels.foreach(channel => { + channels.foreach { channel => executorCompletion.submit(new TxnSubmitter(channel)) - }) + } for (i <- 0 until channels.size) { executorCompletion.take() diff --git a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 67bc64a44466..d0fed303e659 100644 --- a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -377,7 +377,9 @@ public void testForeachRDD() { }); // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java - stream.foreachRDD((rdd, time) -> { return; }); + stream.foreachRDD((rdd, time) -> { + return; + }); JavaTestUtils.runStreams(ssc, 2, 2); @@ -873,7 +875,7 @@ public void testMapWithStateAPI() { JavaMapWithStateDStream stateDstream = wordsDstream.mapWithState( - StateSpec. 
function((time, key, value, state) -> { + StateSpec.function((time, key, value, state) -> { // Use all State's methods here state.exists(); state.get(); diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 41c6ab123bae..80e0cce05586 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -73,7 +73,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") receiver.setCheckpointer(shardId, checkpointer) } catch { - case NonFatal(e) => { + case NonFatal(e) => /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed @@ -84,7 +84,6 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e - } } } else { /* RecordProcessor has been stopped. */ @@ -148,29 +147,25 @@ private[kinesis] object KinesisRecordProcessor extends Logging { /* If the function failed, either retry or throw the exception */ case util.Failure(e) => e match { /* Retry: Throttling or other Retryable exception has occurred */ - case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 - => { - val backOffMillis = Random.nextInt(maxBackOffMillis) - Thread.sleep(backOffMillis) - logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) - retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) - } + case _: ThrottlingException | _: KinesisClientLibDependencyException + if numRetriesLeft > 1 => + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) /* Throw: Shutdown has been requested by the Kinesis Client Library. */ - case _: ShutdownException => { + case _: ShutdownException => logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e - } /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ - case _: InvalidStateException => { + case _: InvalidStateException => logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + s" by the Amazon Kinesis Client Library. 
Table likely doesn't exist.", e) throw e - } /* Throw: Unexpected exception has occurred */ - case _ => { + case _ => logError(s"Unexpected, non-retryable exception.", e) throw e - } } } } diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml new file mode 100644 index 000000000000..68f15dd90502 --- /dev/null +++ b/mllib-local/pom.xml @@ -0,0 +1,74 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-mllib-local_2.11 + + mllib-local + + jar + Spark Project ML Local Library + http://spark.apache.org/ + + + + org.scalanlp + breeze_${scala.binary.version} + + + org.apache.commons + commons-math3 + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + + + netlib-lgpl + + + com.github.fommil.netlib + all + ${netlib.java.version} + pom + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala new file mode 100644 index 000000000000..6b3268cdfa25 --- /dev/null +++ b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +// This is a private class testing if the new build works. To be removed soon. +private[ml] object DummyTesting { + private[ml] def add10(input: Double): Double = input + 10 +} diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala new file mode 100644 index 000000000000..51b7c2409ff2 --- /dev/null +++ b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +// This is testing if the new build works. To be removed soon. +class DummyTestingSuite extends FunSuite { // scalastyle:ignore funsuite + + test("This is testing if the new build works.") { + assert(DummyTesting.add10(15) === 25) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 428176dcbfad..24d8274e2222 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -62,22 +62,21 @@ spark-graphx_${scala.binary.version} ${project.version} + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + test-jar + test + org.scalanlp breeze_${scala.binary.version} - 0.11.2 - - - - junit - junit - - - org.apache.commons - commons-math3 - - org.apache.commons diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 57e416591de6..1247882d6c1b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Estimator's embedded ParamMap. * @return fitted model */ + @Since("2.0.0") @varargs - def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) @@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { copy(paramMap).fit(dataset) } /** * Fits a model to the input data. */ - def fit(dataset: DataFrame): M + @Since("2.0.0") + def fit(dataset: Dataset[_]): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. 
* @return fitted models, matching the input parameter maps */ - def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index afefaaa8832c..82066726a069 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") ( * @param dataset input dataset * @return fitted pipeline */ - @Since("1.2.0") - override def fit(dataset: DataFrame): PipelineModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. @@ -291,10 +291,10 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } - @Since("1.2.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) + stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d23ae6f794d7..81140d1f7b21 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -83,7 +83,7 @@ abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: DataFrame): M = { + override def fit(dataset: Dataset[_]): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) @@ -100,7 +100,7 @@ abstract class Predictor[ * @param dataset Training dataset * @return Fitted model */ - protected def train(dataset: DataFrame): M + protected def train(dataset: Dataset[_]): M /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. @@ -120,7 +120,7 @@ abstract class Predictor[ * Extract [[labelCol]] and [[featuresCol]] from the given dataset, * and put it in an RDD with strong types. 
*/ - protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -171,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @param dataset input dataset * @return transformed dataset with [[predictionCol]] of type [[Double]] */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") - dataset + dataset.toDF } } - protected def transformImpl(dataset: DataFrame): DataFrame = { + protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 2538c0f477fc..a3a2b55adc25 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage { * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ + @Since("2.0.0") @varargs def transform( - dataset: DataFrame, + dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() @@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + @Since("2.0.0") + def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { this.copy(paramMap).transform(dataset) } /** * Transforms the input dataset. 
*/ - def transform(dataset: DataFrame): DataFrame + @Since("2.0.0") + def transform(dataset: Dataset[_]): DataFrame override def copy(extra: ParamMap): Transformer } @@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(this.createTransformFunc, outputDataType) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 8186afc17a53..473e801794c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Output selected columns only. @@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 4525bf71f69e..300ae4339c3c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** @@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) - override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index bee90fb3a568..39a698af153b 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -18,22 +18,24 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -42,13 +44,23 @@ import org.apache.spark.sql.functions._ * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. * Note: Multiclass labels are not currently supported. + * + * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - We expect to implement TreeBoost in the future: + * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") @Experimental final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] - with GBTParams with TreeClassifierParams with Logging { + with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) @@ -105,41 +117,13 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTClassifier: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). 
Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "logistic") + // Parameters from GBTClassifierParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "logistic" => OldLogLoss - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") - } - } - - override protected def train(dataset: DataFrame): GBTClassificationModel = { + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -165,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object GBTClassifier { - // The losses below should be lowercase. +object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { + /** Accessor for supported loss settings: logistic */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTClassifier = super.load(path) } /** @@ -189,7 +176,8 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -210,7 +198,7 @@ final class GBTClassificationModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -257,12 +245,62 @@ final class GBTClassificationModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } -private[ml] object GBTClassificationModel { +@Since("2.0.0") +object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader + + @Since("2.0.0") + override def load(path: String): GBTClassificationModel = super.load(path) + + private[GBTClassificationModel] + class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): 
Unit = { + + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTClassificationModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 37182928cccc..c2b440059b1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -257,12 +257,12 @@ class LogisticRegression @Since("1.2.0") ( this } - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE train(dataset, handlePersistence) } - protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): + protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = @@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instr = Instrumentation.create(this, instances) + instr.logParams(regParam, elasticNetParam, standardization, threshold, + maxIter, tol, fitIntercept) + val (summarizer, labelSummarizer) = { val seqOp = (c: (MultivariateOnlineSummarizer, 
MultiClassSummarizer), instance: Instance) => @@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) + val (coefficients, intercept, objectiveHistory) = { if (numInvalid != 0) { val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + @@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + val m = model.setSummary(logRegSummary) + instr.logSuccess(m) + m } @Since("1.4.0") @@ -544,7 +553,7 @@ class LogisticRegressionModel private[spark] ( * @param dataset Test dataset to evaluate model on. */ @Since("2.0.0") - def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), @@ -774,10 +783,10 @@ sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary */ sealed trait LogisticRegressionSummary extends Serializable { - /** Dataframe outputted by the model's `transform` method. */ + /** Dataframe output by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each class as a vector. */ + /** Field in "predictions" which gives the probability of each class as a vector. */ def probabilityCol: String /** Field in "predictions" which gives the true label of each instance (if available). */ @@ -792,8 +801,8 @@ sealed trait LogisticRegressionSummary extends Serializable { * :: Experimental :: * Logistic regression training results. * - * @param predictions dataframe outputted by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the calibrated probability of + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. @@ -816,8 +825,8 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * :: Experimental :: * Binary Logistic regression results for a given model. * - * @param predictions dataframe outputted by the model's `transform` method. - * @param probabilityCol field in "predictions" which gives the calibrated probability of + * @param predictions dataframe output by the model's `transform` method. + * @param probabilityCol field in "predictions" which gives the probability of * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. 
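A minimal caller-side sketch (not part of the patch) of why widening fit() and transform() from DataFrame to Dataset[_] in this file stays source-compatible: in Spark 2.0 DataFrame is an alias for Dataset[Row], so existing DataFrame callers compile unchanged, while typed Datasets no longer need an explicit .toDF() before being handed to an Estimator. The data and parameter values below are illustrative only, and a spark-shell style environment with sqlContext already defined is assumed.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.Vectors

// Illustrative training data; any DataFrame (or Dataset) with "label" and "features" columns works.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")

val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.01)

// fit(Dataset[_]) accepts the DataFrame directly; transform() still returns a DataFrame.
val model = lr.fit(training)
model.transform(training).select("label", "prediction").show()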
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 79bb2a8855dc..9ff5252e4ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTo import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams @@ -199,7 +199,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { + override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 483ef0d88ca6..267d63b51eb6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** * Params for Naive Bayes Classifiers. 
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") ( def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) - override protected def train(dataset: DataFrame): NaiveBayesModel = { + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 263d54ce4d7e..4de1b877b019 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -140,8 +140,8 @@ final class OneVsRestModel private[ml] ( validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -293,8 +293,8 @@ final class OneVsRest @Since("1.4.0") ( validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } - @Since("1.4.0") - override def fit(dataset: DataFrame): OneVsRestModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): OneVsRestModel = { transformSchema(dataset.schema) // determine number of classes either from metadata if provided, or via computation. 
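A similar caller-side sketch for OneVsRest, whose fit() now also takes Dataset[_]: the reduction still trains one binary classifier per class and returns a OneVsRestModel whose transform() produces a DataFrame. The tiny three-class dataset and the parameter values are illustrative only; sqlContext is again assumed to be in scope.

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.mllib.linalg.Vectors

// Illustrative three-class data; the number of classes is inferred from the label column.
val multiclassTrain = sqlContext.createDataFrame(Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (2.0, Vectors.dense(1.0, 1.0)),
  (0.0, Vectors.dense(0.1, 0.9))
)).toDF("label", "features")

val base = new LogisticRegression().setMaxIter(10).setTol(1e-6)
val ovr = new OneVsRest().setClassifier(base)

// One binary LogisticRegressionModel is trained per class under the hood.
val ovrModel = ovr.fit(multiclassTrain)
ovrModel.transform(multiclassTrain).select("label", "prediction").show()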
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 865614aa5c8a..d00fee12b08c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index cb42532271a8..dfa711b2436c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestClassificationModel = { + override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val 
(metadata: Metadata, treesData: Array[(Metadata, Node)], _) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 55f751c57f3e..6cc9117da3fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering. {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( * centers. */ @Since("2.0.0") - def computeCost(dataset: DataFrame): Double = { + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") ( def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) @Since("2.0.0") - override def fit(dataset: DataFrame): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 120bf3cf9d31..ead8ad780629 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -80,7 +80,7 @@ class GaussianMixtureModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) @@ -238,7 +238,7 @@ class GaussianMixture @Since("2.0.0") ( def setSeed(value: Long): 
this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: DataFrame): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibGM() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a8beef8b120e..b324196842a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -105,8 +105,8 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -126,8 +126,8 @@ class KMeansModel private[ml] ( * model on the given data. */ // TODO: Replace the temp fix when we have proper evaluators defined for clustering. - @Since("1.6.0") - def computeCost(dataset: DataFrame): Double = { + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -143,6 +143,12 @@ class KMeansModel private[ml] ( this } + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + /** * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. @@ -254,8 +260,8 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): KMeansModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() @@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) - val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + val summary = new KMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(summary) } @@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +/** + * :: Experimental :: + * Summary of KMeans. 
+ * + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental class KMeansSummary private[clustering] ( @Since("2.0.0") @transient val predictions: DataFrame, @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val featuresCol: String) extends Serializable { + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { /** * Cluster centers of the transformed data. @@ -296,11 +315,15 @@ class KMeansSummary private[clustering] ( @transient lazy val cluster: DataFrame = predictions.select(predictionCol) /** - * Size of each cluster. + * Size of (number of data points in) each cluster. */ @Since("2.0.0") - lazy val clusterSizes: Array[Int] = cluster.rdd.map { - case Row(clusterIdx: Int) => (clusterIdx, 1) - }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 727b72470838..c57ceba4a997 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType @@ -190,6 +190,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM def getTopicDistributionCol: String = $(topicDistributionCol) /** + * For Online optimizer only: [[optimizer]] = "online". + * * A (positive) learning parameter that downweights early iterations. Larger values make early * iterations count less. * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) @@ -198,8 +200,9 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @group expertParam */ @Since("1.6.0") - final val learningOffset = new DoubleParam(this, "learningOffset", "A (positive) learning" + - " parameter that downweights early iterations. Larger values make early iterations count less.", + final val learningOffset = new DoubleParam(this, "learningOffset", "(For online optimizer)" + + " A (positive) learning parameter that downweights early iterations. Larger values make early" + + " iterations count less.", ParamValidators.gt(0)) /** @group expertGetParam */ @@ -207,6 +210,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM def getLearningOffset: Double = $(learningOffset) /** + * For Online optimizer only: [[optimizer]] = "online". + * * Learning rate, set as an exponential decay rate. * This should be between (0.5, 1.0] to guarantee asymptotic convergence. 
* This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). @@ -215,15 +220,17 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @group expertParam */ @Since("1.6.0") - final val learningDecay = new DoubleParam(this, "learningDecay", "Learning rate, set as an" + - " exponential decay rate. This should be between (0.5, 1.0] to guarantee asymptotic" + - " convergence.", ParamValidators.gt(0)) + final val learningDecay = new DoubleParam(this, "learningDecay", "(For online optimizer)" + + " Learning rate, set as an exponential decay rate. This should be between (0.5, 1.0] to" + + " guarantee asymptotic convergence.", ParamValidators.gt(0)) /** @group expertGetParam */ @Since("1.6.0") def getLearningDecay: Double = $(learningDecay) /** + * For Online optimizer only: [[optimizer]] = "online". + * * Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, * in range (0, 1]. * @@ -239,8 +246,9 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @group param */ @Since("1.6.0") - final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "Fraction of the corpus" + - " to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1].", + final val subsamplingRate = new DoubleParam(this, "subsamplingRate", "(For online optimizer)" + + " Fraction of the corpus to be sampled and used in each iteration of mini-batch" + + " gradient descent, in range (0, 1].", ParamValidators.inRange(0.0, 1.0, lowerInclusive = false, upperInclusive = true)) /** @group getParam */ @@ -248,6 +256,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM def getSubsamplingRate: Double = $(subsamplingRate) /** + * For Online optimizer only (currently): [[optimizer]] = "online". + * * Indicates whether the docConcentration (Dirichlet parameter for * document-topic distribution) will be optimized during training. * Setting this to true will make the model more expressive and fit the training data better. @@ -257,15 +267,17 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM */ @Since("1.6.0") final val optimizeDocConcentration = new BooleanParam(this, "optimizeDocConcentration", - "Indicates whether the docConcentration (Dirichlet parameter for document-topic" + - " distribution) will be optimized during training.") + "(For online optimizer only, currently) Indicates whether the docConcentration" + + " (Dirichlet parameter for document-topic distribution) will be optimized during training.") /** @group expertGetParam */ @Since("1.6.0") def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) /** - * For EM optimizer, if using checkpointing, this indicates whether to keep the last + * For EM optimizer only: [[optimizer]] = "em". + * + * If using checkpointing, this indicates whether to keep the last * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can * cause failures if a data partition is lost, so set this bit with care. * Note that checkpoints will be cleaned up via reference counting, regardless. 
@@ -279,7 +291,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM */ @Since("2.0.0") final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint", - "For EM optimizer, if using checkpointing, this indicates whether to keep the last" + + "(For EM optimizer) If using checkpointing, this indicates whether to keep the last" + " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" + " cause failures if a data partition is lost, so set this bit with care.") @@ -390,15 +402,15 @@ sealed abstract class LDAModel private[ml] ( * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. * This implementation may be changed in the future. */ - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if ($(topicDistributionCol).nonEmpty) { val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF } else { logWarning("LDAModel.transform was called without any output columns. Set an output column" + " such as topicDistributionCol to produce results.") - dataset + dataset.toDF } } @@ -443,8 +455,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus */ - @Since("1.6.0") - def logLikelihood(dataset: DataFrame): Double = { + @Since("2.0.0") + def logLikelihood(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logLikelihood(oldDataset) } @@ -460,8 +472,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. 
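Reviewer note: with the signature change above, a fitted LDAModel can transform and score any Dataset that has the configured features column of term-count vectors. A sketch, assuming corpus and heldOut are such DataFrames (these names are placeholders, not part of the patch):

import org.apache.spark.ml.clustering.LDA

val ldaModel = new LDA().setK(5).setMaxIter(10).fit(corpus)
val scored = ldaModel.transform(heldOut)     // adds "topicDistribution" by default
val ll = ldaModel.logLikelihood(heldOut)     // variational lower bound on log likelihood
val lp = ldaModel.logPerplexity(heldOut)     // variational upper bound on log perplexity per token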
*/ - @Since("1.6.0") - def logPerplexity(dataset: DataFrame): Double = { + @Since("2.0.0") + def logPerplexity(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logPerplexity(oldDataset) } @@ -828,8 +840,8 @@ class LDA @Since("1.6.0") ( @Since("1.6.0") override def copy(extra: ParamMap): LDA = defaultCopy(extra) - @Since("1.6.0") - override def fit(dataset: DataFrame): LDAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): LDAModel = { transformSchema(dataset.schema, logging = true) val oldLDA = new OldLDA() .setK($(k)) @@ -861,7 +873,7 @@ class LDA @Since("1.6.0") ( private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ - def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 337ffbe90f36..bde8c275fda4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va setDefault(metricName -> "areaUnderROC") - @Since("1.2.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 0f22cca3a78d..5f765c071b9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, Params} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -36,8 +36,8 @@ abstract class Evaluator extends Params { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = { this.copy(paramMap).evaluate(dataset) } @@ -46,8 +46,8 @@ abstract class Evaluator extends Params { * @param dataset a dataset that contains labels/observations and predictions. 
* @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame): Double + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): Double /** * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 55ff44323a79..3acfc221c95e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid setDefault(metricName -> "f1") - @Since("1.5.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 9976d7ed43a6..ed04b67bcc93 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType} @@ -39,11 +39,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * Param for metric name in evaluation. Supports: + * - `"rmse"` (default): root mean squared error + * - `"mse"`: mean squared error + * - `"r2"`: R^2^ metric + * - `"mae"`: mean absolute error * - * Because we will maximize evaluation value (ref: `CrossValidator`), - * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), - * we take and output the negative of this metric. 
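Reviewer note: evaluators now take Dataset[_] as well. A sketch for the multiclass evaluator, assuming predictions is the output of a fitted classifier with the default "label" and "prediction" columns:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val f1 = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("f1")                 // default metric
  .evaluate(predictions)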
* @group param */ @Since("1.4.0") @@ -70,8 +71,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui setDefault(metricName -> "rmse") - @Since("1.4.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2f8e3a0371a4..898ac2cc8941 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -64,7 +64,8 @@ final class Binarizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) val schema = dataset.schema val inputType = schema($(inputCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 33abc7c99d4b..10e622ace6d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val bucketizer = udf { feature: Double => Bucketizer.binarySearchForBuckets($(splits), feature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index b9e9d5685360..cfecae7e0b15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String) /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def fit(dataset: DataFrame): ChiSqSelectorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => @@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] ( /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last val selector = udf { chiSqSelector.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 5694b3890fba..922670a41b6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ 
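Reviewer note: the rewritten RegressionEvaluator doc above enumerates the supported metrics; usage is unchanged apart from the Dataset[_] signature. Sketch, assuming predictions has double-typed "label" and "prediction" columns:

import org.apache.spark.ml.evaluation.RegressionEvaluator

val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")

val rmse = evaluator.setMetricName("rmse").evaluate(predictions)  // root mean squared error
val mae  = evaluator.setMetricName("mae").evaluate(predictions)   // mean absolute error
val r2   = evaluator.setMetricName("r2").evaluate(predictions)    // R^2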
b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -100,6 +100,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** @group getParam */ def getMinTF: Double = $(minTF) + + /** + * Binary toggle to control the output vector values. + * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for + * discrete probabilistic models that model binary events rather than integer counts. + * Default: false + * @group param + */ + val binary: BooleanParam = + new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.") + + /** @group getParam */ + def getBinary: Boolean = $(binary) + + setDefault(binary -> false) } /** @@ -127,9 +142,13 @@ class CountVectorizer(override val uid: String) /** @group setParam */ def setMinTF(value: Double): this.type = set(minTF, value) + /** @group setParam */ + def setBinary(value: Boolean): this.type = set(binary, value) + setDefault(vocabSize -> (1 << 18), minDF -> 1) - override def fit(dataset: DataFrame): CountVectorizerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) @@ -152,16 +171,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) @@ -206,30 +219,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin /** @group setParam */ def setMinTF(value: Double): this.type = set(minTF, value) - /** - * Binary toggle to control the output vector values. - * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for - * discrete probabilistic models that model binary events rather than integer counts. - * Default: false - * @group param - */ - val binary: BooleanParam = - new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. 
" + - "This is useful for discrete probabilistic models that model binary events rather " + - "than integer counts") - - /** @group getParam */ - def getBinary: Boolean = $(binary) - /** @group setParam */ def setBinary(value: Boolean): this.type = set(binary, value) - setDefault(binary -> false) - /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0f7ae5a10035..467ad7307462 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidat import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} @@ -77,7 +77,8 @@ class HashingTF(override val uid: String) /** @group setParam */ def setBinary(value: Boolean): this.type = set(binary, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) val t = udf { terms: Seq[_] => hashingTF.transform(terms) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f36cf503a0b8..5075b78c9856 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - override def fit(dataset: DataFrame): IDFModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) @@ -115,7 +116,8 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } dataset.withColumn($(outputCol), idf(col($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index d3fe6e528f0b..9ca34e9ae22f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -27,7 +27,7 @@ import 
org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 7de5a4d5d314..e9df600c8a99 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): MaxAbsScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: this looks hack, we may have to handle sparse and dense vectors separately. 
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b13684a1cb76..125becbb8a5b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String) /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def fit(dataset: DataFrame): MinMaxScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4f67042629c5..99357793dbae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} @@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 305c3d187fcb..9cf722e12169 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. */ - override def fit(dataset: DataFrame): PCAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) @@ -124,7 +125,8 @@ class PCAModel private[ml] ( * NOTE: Vectors to be transformed must be the same length * as the source vectors given to [[PCA.fit()]]. 
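Reviewer note: the scalers follow the same estimator/model pattern after the Dataset[_] change. A MinMaxScaler sketch with illustrative data (sqlContext assumed in scope):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.ml.feature.MinMaxScaler

val vecs = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.1)),
  (1, Vectors.dense(2.0, 1.1)),
  (2, Vectors.dense(3.0, 10.1))
)).toDF("id", "features")

val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .setMin(0.0)
  .setMax(1.0)

val scalerModel = scaler.fit(vecs)        // learns per-feature min/max
scalerModel.transform(vecs).show(false)   // rescales each feature into [0, 1]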
*/ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) val pcaOp = udf { pcaModel.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e486e92c12aa..5c7993af645a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom @@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol with HasSeed { /** - * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must + * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. * default: 2 * @group param @@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + + /** + * Relative error (see documentation for + * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description) + * Must be a number in [0, 1]. + * default: 0.001 + * @group param + */ + val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + + "for approxQuantile", + ParamValidators.inRange(0.0, 1.0)) + setDefault(relativeError -> 0.001) + + /** @group getParam */ + def getRelativeError: Double = getOrDefault(relativeError) } /** @@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, - * covering all real values. This attempts to find numBuckets partitions based on a sample of data, - * but it may find fewer depending on the data sample values. + * covering all real values. 
*/ @Experimental final class QuantileDiscretizer(override val uid: String) @@ -65,6 +79,9 @@ final class QuantileDiscretizer(override val uid: String) def this() = this(Identifiable.randomUID("quantileDiscretizer")) + /** @group setParam */ + def setRelativeError(value: Double): this.type = set(relativeError, value) + /** @group setParam */ def setNumBuckets(value: Int): this.type = set(numBuckets, value) @@ -87,12 +104,13 @@ final class QuantileDiscretizer(override val uid: String) StructType(outputFields) } - override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer - .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) - .map { case Row(feature: Double) => feature } - val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) - val splits = QuantileDiscretizer.getSplits(candidates) + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Bucketizer = { + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + splits(0) = Double.NegativeInfinity + splits(splits.length - 1) = Double.PositiveInfinity + val bucketizer = new Bucketizer(uid).setSplits(splits) copyValues(bucketizer.setParent(this)) } @@ -103,90 +121,6 @@ final class QuantileDiscretizer(override val uid: String) @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - /** - * Minimum number of samples required for finding splits, regardless of number of bins. If - * the dataset has fewer rows than this value, the entire dataset will be used. - */ - private[spark] val minSamplesRequired: Int = 10000 - - /** - * Sampling from the given dataset to collect quantile statistics. - */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { - val totalSamples = dataset.count() - require(totalSamples > 0, - "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, minSamplesRequired) - val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() - } - - /** - * Compute split points with respect to the sample distribution. - */ - private[feature] - def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { - val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) - } - val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1)) - val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.dropRight(1).map(_._1) - } else { - val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1)) - val splitsBuilder = mutable.ArrayBuilder.make[Double] - var index = 1 - // currentCount: sum of counts of values that have been visited - var currentCount = valueCounts(0)._2 - // targetCount: target value for `currentCount`. If `currentCount` is closest value to - // `targetCount`, then current value is a split threshold. After finding a split threshold, - // `targetCount` is added by stride. 
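Reviewer note: QuantileDiscretizer.fit now delegates split finding to DataFrameStatFunctions.approxQuantile, controlled by the new relativeError param, and still returns a Bucketizer. Usage sketch (sqlContext assumed in scope, data illustrative):

import org.apache.spark.ml.feature.QuantileDiscretizer

val values = sqlContext.createDataFrame(
  (0 until 1000).map(i => (i, i.toDouble))
).toDF("id", "hour")

val discretizer = new QuantileDiscretizer()
  .setInputCol("hour")
  .setOutputCol("bucket")
  .setNumBuckets(4)
  .setRelativeError(0.001)   // new param: target precision of the approximate quantiles

val bucketizer = discretizer.fit(values)        // a Bucketizer with learned splits
println(bucketizer.getSplits.mkString(", "))    // -Inf and +Inf at the ends
bucketizer.transform(values).show(5)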
- var targetCount = stride - while (index < valueCounts.length) { - val previousCount = currentCount - currentCount += valueCounts(index)._2 - val previousGap = math.abs(previousCount - targetCount) - val currentGap = math.abs(currentCount - targetCount) - // If adding count of current value to currentCount makes the gap between currentCount and - // targetCount smaller, previous value is a split threshold. - if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 - targetCount += stride - } - index += 1 - } - splitsBuilder.result() - } - } - - /** - * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as - * needed, and adding a default split value of 0 if no good candidates are found. - */ - private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { - val effectiveValues = if (candidates.nonEmpty) { - if (candidates.head == Double.NegativeInfinity - && candidates.last == Double.PositiveInfinity) { - candidates.drop(1).dropRight(1) - } else if (candidates.head == Double.NegativeInfinity) { - candidates.drop(1) - } else if (candidates.last == Double.PositiveInfinity) { - candidates.dropRight(1) - } else { - candidates - } - } else { - candidates - } - - if (effectiveValues.isEmpty) { - Array(Double.NegativeInfinity, 0, Double.PositiveInfinity) - } else { - Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 12a76dbbfb4b..3ac6c776699a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ /** @@ -103,7 +103,8 @@ class RFormula(override val uid: String) RFormulaParser.parse($(formula)).hasIntercept } - override def fit(dataset: DataFrame): RFormulaModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): RFormulaModel = { require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -204,7 +205,8 @@ class RFormulaModel private[feature]( private[ml] val pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase with MLWritable { - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -232,10 +234,10 @@ class RFormulaModel private[feature]( override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" - private def transformLabel(dataset: DataFrame): DataFrame = { + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { - dataset + dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => @@ -246,7 +248,7 @@ class 
RFormulaModel private[feature]( } else { // Ignore the label field. This is a hack so that this transformer can also work on test // datasets in a Pipeline. - dataset + dataset.toDF } } @@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str def this(columnsToPrune: Set[String]) = this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) dataset.select(columnsToKeep.map(dataset.col): _*) } @@ -396,7 +398,7 @@ private class VectorAttributeRewriter( def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index e0ca45b9a619..2002d15745d9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.types.StructType /** @@ -63,13 +63,12 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor private val tableIdentifier: String = "__THIS__" - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } @Since("1.6.0") @@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 26ee8e1bf166..118a6e3e6ad4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - 
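Reviewer note: SQLTransformer still substitutes the __THIS__ placeholder with a temporary table name; the change above makes transformSchema register (and drop) a uniquely named temp table instead of __THIS__ itself. Usage is unchanged, e.g.:

import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.createDataFrame(Seq(
  (0, 1.0, 3.0),
  (2, 2.0, 5.0)
)).toDF("id", "v1", "v2")

val sqlTrans = new SQLTransformer()
  .setStatement("SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

sqlTrans.transform(df).show()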
override def fit(dataset: DataFrame): StandardScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) @@ -135,7 +136,8 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0a0e0b0960c8..b96bc48566fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructType} @@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String) setDefault(stopWords -> StopWords.English, caseSensitive -> false) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index faa0f6f407b3..7e0d374f0272 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): StringIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -144,11 +145,12 @@ class StringIndexerModel ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist 
during transformation. " + "Skip StringIndexerModel.") - return dataset + return dataset.toDF } validateAndTransformSchema(dataset.schema) @@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String) StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if ($(labels).isEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 957e8e7a5983..4d3e46e488c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Schema transformation. val schema = dataset.schema - lazy val first = dataset.first() + lazy val first = dataset.toDF.first() val attrs = $(inputCols).flatMap { c => val field = schema(c) val index = schema.fieldIndex(c) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index bf4aef2a74c7..68b699d569c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): VectorIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") @@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } 
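Reviewer note: as shown above, StringIndexerModel.transform now returns dataset.toDF when the input column is absent (instead of the typed Dataset). The usual StringIndexer / IndexToString round trip is otherwise unchanged; a sketch:

import org.apache.spark.ml.feature.{IndexToString, StringIndexer}

val labels = sqlContext.createDataFrame(Seq(
  (0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a")
)).toDF("id", "category")

val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(labels)
val indexed = indexerModel.transform(labels)

// Map indices back to the original strings using the fitted model's labels.
val converted = new IndexToString()
  .setInputCol("categoryIndex")
  .setOutputCol("originalCategory")
  .setLabels(indexerModel.labels)
  .transform(indexed)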
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index b60e82de00c0..7a9468b87b73 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType @@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 95bae1c8a312..a72692960f92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame): Word2VecModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() @@ -219,7 +220,8 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. 
*/ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 40590e71c42a..783546862689 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class AFTSurvivalRegressionWrapper private ( pipeline: PipelineModel, @@ -43,8 +43,8 @@ private[r] class AFTSurvivalRegressionWrapper private ( features ++ Array("Log(scale)") } - def transform(dataset: DataFrame): DataFrame = { - pipeline.transform(dataset) + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(aftModel.getFeaturesCol) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala new file mode 100644 index 000000000000..475a3083854e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
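Reviewer note: for context on the Word2VecModel.transform change above (sentence vectors are the average of the word vectors), a minimal end-to-end sketch with toy sentences (sqlContext assumed in scope):

import org.apache.spark.ml.feature.Word2Vec

val sentences = sqlContext.createDataFrame(Seq(
  "Hi I heard about Spark".split(" "),
  "I wish Java could use case classes".split(" "),
  "Logistic regression models are neat".split(" ")
).map(Tuple1.apply)).toDF("text")

val word2Vec = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setMinCount(0)

val w2vModel = word2Vec.fit(sentences)
w2vModel.transform(sentences).show(false)   // one averaged vector per sentence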
+ */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.regression._ +import org.apache.spark.sql._ + +private[r] class GeneralizedLinearRegressionWrapper private ( + pipeline: PipelineModel, + val features: Array[String]) { + + private val glm: GeneralizedLinearRegressionModel = + pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] + + lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray + } else { + glm.coefficients.toArray + } + + lazy val rFeatures: Array[String] = if (glm.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } + + def transform(dataset: DataFrame): DataFrame = { + pipeline.transform(dataset).drop(glm.getFeaturesCol) + } +} + +private[r] object GeneralizedLinearRegressionWrapper { + + def fit( + formula: String, + data: DataFrame, + family: String, + link: String, + epsilon: Double, + maxit: Int): GeneralizedLinearRegressionWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + val rFormulaModel = rFormula.fit(data) + // get labels and feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline + val glm = new GeneralizedLinearRegression() + .setFamily(family) + .setLink(link) + .setFitIntercept(rFormula.hasIntercept) + .setTol(epsilon) + .setMaxIter(maxit) + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, glm)) + .fit(data) + new GeneralizedLinearRegressionWrapper(pipeline, features) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ed735a4ea399..9e2b81ee2014 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.clustering.{KMeans, KMeansModel} import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class KMeansWrapper private ( pipeline: PipelineModel) { @@ -37,7 +37,7 @@ private[r] class KMeansWrapper private ( lazy val k: Int = kMeansModel.getK - lazy val size: Array[Int] = kMeansModel.summary.clusterSizes + lazy val size: Array[Long] = kMeansModel.summary.clusterSizes lazy val cluster: DataFrame = kMeansModel.summary.cluster @@ -52,7 +52,7 @@ private[r] class KMeansWrapper private ( } } - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 07383d393d63..b17207e99bb8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import 
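Reviewer note: the new GeneralizedLinearRegressionWrapper above assembles a Pipeline(RFormula, GeneralizedLinearRegression). The equivalent Scala-side usage outside of SparkR looks roughly like the sketch below, assuming training is a DataFrame with a numeric column y and predictors x1, x2 (names are placeholders):

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, GeneralizedLinearRegressionModel}

val rFormula = new RFormula().setFormula("y ~ x1 + x2")

val glr = new GeneralizedLinearRegression()
  .setFamily("gaussian")
  .setLink("identity")
  .setFitIntercept(true)
  .setTol(1e-6)      // corresponds to R's epsilon
  .setMaxIter(25)    // corresponds to R's maxit

val pipelineModel = new Pipeline().setStages(Array(rFormula, glr)).fit(training)
val glrModel = pipelineModel.stages(1).asInstanceOf[GeneralizedLinearRegressionModel]
println(glrModel.intercept)
println(glrModel.coefficients)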
org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class NaiveBayesWrapper private ( pipeline: PipelineModel, @@ -36,8 +36,10 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) - def transform(dataset: DataFrame): DataFrame = { - pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(naiveBayesModel.getFeaturesCol) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala deleted file mode 100644 index 551e75dc0a02..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.api.r - -import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.sql.DataFrame - -private[r] object SparkRWrappers { - def fitRModelFormula( - value: String, - df: DataFrame, - family: String, - lambda: Double, - alpha: Double, - standardize: Boolean, - solver: String): PipelineModel = { - val formula = new RFormula().setFormula(value) - val estimator = family match { - case "gaussian" => new LinearRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - .setSolver(solver) - case "binomial" => new LogisticRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - } - val pipeline = new Pipeline().setStages(Array(formula, estimator)) - pipeline.fit(df) - } - - def getModelCoefficients(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => { - val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ - m.summary.coefficientStandardErrors.dropRight(1) - val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) - val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) - if (m.getFitIntercept) { - Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ - tValuesR ++ pValuesR - } else { - m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR - } - } - case m: LogisticRegressionModel => { - if (m.getFitIntercept) { - Array(m.intercept) ++ m.coefficients.toArray - } else { - m.coefficients.toArray - } - } - } - } - - def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => - m.summary.devianceResiduals - case m: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No deviance residuals available for LogisticRegressionModel") - } - } - - def getModelFeatures(model: PipelineModel): Array[String] = { - model.stages.last match { - case m: LinearRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - if (m.getFitIntercept) { - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } else { - attrs.attributes.get.map(_.name.get) - } - case m: LogisticRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - if (m.getFitIntercept) { - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } else { - attrs.attributes.get.map(_.name.get) - } - } - } - - def getModelName(model: PipelineModel): String = { - model.stages.last match { - case m: LinearRegressionModel => - "LinearRegressionModel" - case m: LogisticRegressionModel => - "LogisticRegressionModel" - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4a3ad662a0d3..36dce015908e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,7 @@ import 
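// The deleted getModelCoefficients() above rotated summary statistics into R's layout:
// Spark's linear regression summary reports the intercept's statistics last, whereas R's
// summary() prints "(Intercept)" first. A minimal sketch of that rotation, with hypothetical
// values:
val stdErrs = Array(0.12, 0.34, 0.05)                        // coef1, coef2, intercept
val stdErrsR = Array(stdErrs.last) ++ stdErrs.dropRight(1)   // Array(0.05, 0.12, 0.34)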
org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -200,8 +200,8 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - @Since("1.3.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] this } - @Since("1.3.0") - override def fit(dataset: DataFrame): ALSModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3278974954ed..89ba6ab5d277 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -31,8 +31,9 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -183,7 +184,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. 
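// Sketch of what the DataFrame -> Dataset[_] signature change buys callers: a typed Dataset
// and an untyped DataFrame can both be passed to fit()/transform() directly. The SQLContext
// `sqlContext` is assumed; `Rating` here is a hypothetical case class, not Spark's.
import org.apache.spark.ml.recommendation.ALS
import sqlContext.implicits._

case class Rating(user: Int, item: Int, rating: Float)
val ratingsDS = Seq(Rating(0, 1, 4.0f), Rating(1, 1, 2.0f), Rating(0, 2, 3.0f)).toDS()

val als = new ALS().setUserCol("user").setItemCol("item").setRatingCol("rating")
val modelFromDS = als.fit(ratingsDS)         // typed Dataset[Rating]
val modelFromDF = als.fit(ratingsDS.toDF())  // plain DataFrame still works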
*/ - protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) .rdd.map { case Row(features: Vector, label: Double, censor: Double) => @@ -191,17 +192,27 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } } - @Since("1.6.0") - override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val costFun = new AFTCostFun(instances, $(fitIntercept)) + val featuresSummarizer = { + val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features) + val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { + c1.merge(c2) + } + instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + + val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + val numFeatures = featuresStd.size /* The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter @@ -230,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val rawCoefficients = parameters.slice(2, parameters.length) + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale) @@ -299,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] ( math.exp(BLAS.dot(coefficients, features) + intercept) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} @@ -434,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. 
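// Sketch of the standardization scheme introduced in fit() above (all names here are
// assumptions): per-feature standard deviations come from a MultivariateOnlineSummarizer,
// L-BFGS optimizes over features divided by their std, and the fitted coefficients are
// mapped back to the original scale afterwards.
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD

def featureStd(features: RDD[Vector]): Array[Double] = {
  val summarizer = features.treeAggregate(new MultivariateOnlineSummarizer)(
    seqOp = (s, v) => s.add(v),
    combOp = (s1, s2) => s1.merge(s2))
  summarizer.variance.toArray.map(math.sqrt)
}

// Undo the scaling after optimization, as the rewritten fit() does:
def unscale(rawCoefficients: Array[Double], std: Array[Double]): Array[Double] =
  rawCoefficients.zip(std).map { case (beta, s) => if (s != 0.0) beta / s else 0.0 }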
*/ -private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) - extends Serializable { +private class AFTAggregator( + parameters: BDV[Double], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends Serializable { // the regression coefficients to the covariates private val coefficients = parameters.slice(2, parameters.length) - private val intercept = parameters.valueAt(1) + private val intercept = parameters(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) - private var gradientInterceptSum = 0.0 - private var gradientLogSigmaSum = 0.0 + // Here we optimize loss function over log(sigma), intercept and coefficients + private val gradientSumArray = Array.ofDim[Double](parameters.length) def count: Long = totalCnt + def loss: Double = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + lossSum / totalCnt + } + def gradient: BDV[Double] = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + new BDV(gradientSumArray.map(_ / totalCnt.toDouble)) + } - def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - - // Here we optimize loss function over coefficients, intercept and log(sigma) - def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -466,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def add(data: AFTPoint): this.type = { - - val interceptFlag = if (fitIntercept) 1.0 else 0.0 - - val xi = data.features.toBreeze + val xi = data.features val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma - lossSum += math.log(sigma) * delta - lossSum += (math.exp(epsilon) - delta * epsilon) + val margin = { + var sum = 0.0 + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / featuresStd(index)) + } + } + sum + intercept + } + val epsilon = (math.log(ti) - margin) / sigma - // Sanity check (should never occur): - assert(!lossSum.isInfinity, - s"AFTAggregator loss sum is infinity. 
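For reference, the quantities accumulated by the rewritten AFTAggregator above are the per-instance negative log-likelihood of the Weibull AFT model and its gradient, written with the std-scaled features $x_i$, censoring indicator $\delta_i$, and $\epsilon_i = (\log t_i - x_i^\top\beta - \beta_0)/\sigma$:

$$\ell_i = \delta_i\log\sigma - \delta_i\epsilon_i + e^{\epsilon_i},\qquad
\frac{\partial \ell_i}{\partial \beta_j} = \frac{\delta_i - e^{\epsilon_i}}{\sigma}\,x_{ij},\qquad
\frac{\partial \ell_i}{\partial \beta_0} = \frac{\delta_i - e^{\epsilon_i}}{\sigma},\qquad
\frac{\partial \ell_i}{\partial \log\sigma} = \delta_i + (\delta_i - e^{\epsilon_i})\,\epsilon_i.$$

The code's `multiplier` is $(\delta_i - e^{\epsilon_i})/\sigma$, and slot 0 of `gradientSumArray` holds the derivative with respect to $\log\sigma$.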
Error for unknown reason.") + lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon) - val deltaMinusExpEps = delta - math.exp(epsilon) - gradientCoefficientSum += xi * deltaMinusExpEps / sigma - gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma - gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon + val multiplier = (delta - math.exp(epsilon)) / sigma + + gradientSumArray(0) += delta + multiplier * sigma * epsilon + gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + } + } totalCnt += 1 this @@ -503,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientCoefficientSum += other.gradientCoefficientSum - gradientInterceptSum += other.gradientInterceptSum - gradientLogSigmaSum += other.gradientLogSigmaSum + var i = 0 + val len = this.gradientSumArray.length + while (i < len) { + this.gradientSumArray(i) += other.gradientSumArray(i) + i += 1 + } } this } @@ -516,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. */ -private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) - extends DiffFunction[BDV[Double]] { +private class AFTCostFun( + data: RDD[AFTPoint], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( + val aftAggregator = data.treeAggregate( + new AFTAggregator(parameters, fitIntercept, featuresStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 1289a317ee7f..c04c416aaf19 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } - override def transform(dataset: DataFrame): DataFrame = { + 
@Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset + var output = dataset.toDF if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index cef7c643d7bf..741724d7a104 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,22 +18,23 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -41,12 +42,24 @@ import org.apache.spark.sql.functions._ * [[http://en.wikipedia.org/wiki/Gradient_boosting Gradient-Boosted Trees (GBTs)]] * learning algorithm for regression. * It supports both continuous and categorical features. + * + * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. 
+ * - We expect to implement TreeBoost in the future: + * [https://issues.apache.org/jira/browse/SPARK-4240] */ @Since("1.4.0") @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -100,42 +113,13 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - - override protected def train(dataset: DataFrame): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -152,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. 
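// An illustrative use of GBTRegressor with the loss types listed above and the ML
// persistence this patch adds (the reader/writer appear just below). The training
// DataFrame `df` with "features"/"label" columns and the save path are assumptions.
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

val gbt = new GBTRegressor()
  .setMaxIter(20)
  .setLossType("absolute")   // "squared" (L2) or "absolute" (L1)
val gbtModel = gbt.fit(df)

gbtModel.write.overwrite().save("/tmp/spark-gbt-model")
val restored = GBTRegressionModel.load("/tmp/spark-gbt-model")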
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -176,7 +163,8 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -197,7 +185,7 @@ final class GBTRegressionModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -243,12 +231,64 @@ final class GBTRegressionModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, 
treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a40d3731cbfc..e92a3e7fa1f0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -165,7 +165,11 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val setDefault(tol -> 1E-6) /** - * Sets the regularization parameter. + * Sets the regularization parameter for L2 regularization. + * The regularization term is + * {{{ + * 0.5 * regParam * L2norm(coefficients)^2 + * }}} * Default is 0.0. * @group setParam */ @@ -192,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") - override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { Link.fromName($(link)) @@ -233,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, wlsModel.diagInvAtWA.toArray, - 1) + 1, + getSolver) return model.setSummary(trainingSummary) } @@ -253,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations) + irlsModel.numIterations, + getSolver) model.setSummary(trainingSummary) } @@ -772,11 +778,12 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr * :: Experimental :: * Summarizing Generalized Linear regression Fits. 
* - * @param predictions predictions outputted by the model's `transform` method + * @param predictions predictions output by the model's `transform` method * @param predictionCol field in "predictions" which gives the prediction value of each instance * @param model the model that should be summarized * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param numIterations number of iterations + * @param solver the solver algorithm used for model training */ @Since("2.0.0") @Experimental @@ -785,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") val predictionCol: String, @Since("2.0.0") val model: GeneralizedLinearRegressionModel, private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int) extends Serializable { + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { import GeneralizedLinearRegression._ @@ -933,6 +941,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val coefficientStandardErrors: Array[Double] = { @@ -941,6 +952,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * T-statistic of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val tValues: Array[Double] = { @@ -954,6 +968,9 @@ class GeneralizedLinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * + * If [[GeneralizedLinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. */ @Since("2.0.0") lazy val pValues: Array[Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bd0b631d897b..7a78ecbdf16d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. 
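// A short sketch of the GeneralizedLinearRegression fit and the summary fields documented
// above (the training DataFrame `df` with "features"/"label" columns is an assumption):
import org.apache.spark.ml.regression.GeneralizedLinearRegression

val glr = new GeneralizedLinearRegression()
  .setFamily("gaussian")
  .setLink("identity")
  .setRegParam(0.0)   // values > 0 add 0.5 * regParam * ||coefficients||^2 to the objective
val glrModel = glr.fit(df)

val glrSummary = glrModel.summary
println(glrSummary.solver)                                    // solver used for training
println(glrSummary.coefficientStandardErrors.mkString(", "))  // intercept's entry comes last
println(glrSummary.pValues.mkString(", "))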
*/ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 9619e72a4594..71e02730c721 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -38,7 +38,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") - override protected def train(dataset: DataFrame): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size @@ -417,7 +417,7 @@ class LinearRegressionModel private[ml] ( * @param dataset Test dataset to evaluate model on. */ @Since("2.0.0") - def evaluate(dataset: DataFrame): LinearRegressionSummary = { + def evaluate(dataset: Dataset[_]): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, @@ -513,7 +513,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { * Linear regression training results. Currently, the training summary ignores the * training weights except for the objective trace. * - * @param predictions predictions outputted by the model's `transform` method. + * @param predictions predictions output by the model's `transform` method. 
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration. */ @Since("1.5.0") @@ -549,7 +549,7 @@ class LinearRegressionTrainingSummary private[regression] ( * :: Experimental :: * Linear regression results evaluated on a dataset. * - * @param predictions predictions outputted by the model's `transform` method. + * @param predictions predictions output by the model's `transform` method. * @param predictionCol Field in "predictions" which gives the predicted value of the label at * each instance. * @param labelCol Field in "predictions" which gives the true label of each instance. @@ -655,8 +655,11 @@ class LinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. - * * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * * @see [[LinearRegression.solver]] */ lazy val coefficientStandardErrors: Array[Double] = { @@ -679,8 +682,11 @@ class LinearRegressionSummary private[regression] ( /** * T-statistic of estimated coefficients and intercept. - * * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. + * * @see [[LinearRegression.solver]] */ lazy val tValues: Array[Double] = { @@ -699,8 +705,11 @@ class LinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. - * * This value is only available when using the "normal" solver. + * + * If [[LinearRegression.fitIntercept]] is set to true, + * then the last element returned corresponds to the intercept. 
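// Sketch of the behaviour documented above: these statistics require the "normal" solver,
// and with fitIntercept the intercept's entry comes last. The DataFrame `df` is assumed.
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression().setSolver("normal").setFitIntercept(true)
val lrModel = lr.fit(df)
val tVals = lrModel.summary.tValues
println(s"intercept t-statistic: ${tVals.last}")   // last element corresponds to the intercept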
+ * * @see [[LinearRegression.solver]] */ lazy val pValues: Array[Double] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 736cd9f776bd..4c4ff278d4eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -93,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -164,7 +164,7 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) @@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 2e9b6be9a26b..2f1f2523fd11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -26,11 +26,9 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.annotation.Since -import org.apache.spark.broadcast.Broadcast import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder @@ -40,7 +38,6 @@ import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFil import org.apache.spark.sql.sources._ import 
org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet private[libsvm] class LibSVMOutputWriter( path: String, @@ -178,39 +175,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: This does not handle cases where column pruning has been performed. - - verifySchema(dataSchema) - val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString - else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data") - else throw new IOException("Multiple input paths are not supported for libsvm data.") - - val numFeatures = options.getOrElse("numFeatures", "-1").toInt - val vectorType = options.getOrElse("vectorType", "sparse") - - val sc = sqlContext.sparkContext - val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - val sparse = vectorType == "sparse" - baseRdd.map { pt => - val features = if (sparse) pt.features.toSparse else pt.features.toDense - Row(pt.label, features) - }.mapPartitions { externalRows => - val converter = RowEncoder(dataSchema) - externalRows.map(converter.toRow) - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, @@ -218,6 +182,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + verifySchema(dataSchema) val numFeatures = options("numFeatures").toInt assert(numFeatures > 0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index df8eb5d1f927..c7cde1563fc7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging { } case _ => featureSubsetStrategy } + + val isIntRegex = "^([1-9]\\d*)$".r + val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt + case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt + case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index d3656556742b..b6334762c7a7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -30,22 +30,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -/** - * A package that implements 
- * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] - * for regression and binary classification. - * - * The implementation is based upon: - * J.H. Friedman. "Stochastic Gradient Boosting." 1999. - * - * Notes on Gradient Boosting vs. TreeBoost: - * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. - * - Both algorithms learn tree ensembles by minimizing loss functions. - * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes - * based on the loss function, whereas the original gradient boosting method does not. - * - When the loss is SquaredError, these methods give the same result, but they could differ - * for other loss functions. - */ private[spark] object GradientBoostedTrees extends Logging { /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index c4ab673d9a7e..f38e1ec7c09a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite { sql: SQLContext, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) - val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map { + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { case (tree, treeID) => - treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext) + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata") + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") .write.parquet(treesMetadataPath) val dataPath = new Path(path, "data").toString val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { @@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SQLContext, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite { } val treesMetadataPath = new Path(path, "treesMetadata").toString - val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath) - .select("treeID", "metadata").as[(Int, String)].rdd.map { - case (treeID: Int, json: String) => - treeID -> DefaultParamsReader.parseMetadata(json, treeClassName) + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights) } - val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect() + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) val dataPath = new Path(path, "data").toString val nodeData: 
Dataset[EnsembleNodeData] = @@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite { treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() - (metadata, treesMetadata.zip(rootNodes)) + (metadata, treesMetadata.zip(rootNodes), treesWeights) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 78e6d3bfacb5..b6783911ad11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { * - "onethird": use 1/3 of the features * - "sqrt": use sqrt(number of features) * - "log2": use log2(number of features) + * - "n": when n is in the range (0, 1.0], use n * number of features. When n + * is in the range (1, number of features), use n features. * (default = "auto") * * These various settings are based on the following references: @@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { "The number of features to consider for splits at each tree node." + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) setDefault(featureSubsetStrategy -> "auto") @@ -393,6 +396,9 @@ private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + + // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) + final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" } private[ml] trait RandomForestClassifierParams @@ -456,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. 
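// The new numeric featureSubsetStrategy values documented above, as a caller might use them
// (a training DataFrame `df` with "features"/"label" columns is assumed):
import org.apache.spark.ml.regression.RandomForestRegressor

val rf = new RandomForestRegressor()
  .setNumTrees(50)
  .setFeatureSubsetStrategy("0.5")    // a fraction in (0, 1]: use half of the features
// .setFeatureSubsetStrategy("10")    // or an integer: use 10 features, capped at numFeatures
val rfModel = rf.fit(df)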
(case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. 
+ throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4d9d4d472e17..de563d4fad7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -90,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.4.0") - override def fit(dataset: DataFrame): CrossValidatorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -100,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -209,8 +209,8 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0f2179c2a126..12d6905510c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.tuning import java.util.{List => JList} import scala.collection.JavaConverters._ +import scala.language.existentials import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -31,7 +32,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -89,8 +90,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): TrainValidationSplitModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -207,8 +208,8 @@ class TrainValidationSplitModel 
private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala new file mode 100644 index 000000000000..7e57cefc4449 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.util.concurrent.atomic.AtomicLong + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.Param +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset + +/** + * A small wrapper that defines a training session for an estimator, and some methods to log + * useful information during this session. + * + * A new instance is expected to be created within fit(). + * + * @param estimator the estimator that is being fit + * @param dataset the training dataset + * @tparam E the type of the estimator + */ +private[ml] class Instrumentation[E <: Estimator[_]] private ( + estimator: E, dataset: RDD[_]) extends Logging { + + private val id = Instrumentation.counter.incrementAndGet() + private val prefix = { + val className = estimator.getClass.getSimpleName + s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + } + + init() + + private def init(): Unit = { + log(s"training: numPartitions=${dataset.partitions.length}" + + s" storageLevel=${dataset.getStorageLevel}") + } + + /** + * Logs a message with a prefix that uniquely identifies the training session. + */ + def log(msg: String): Unit = { + logInfo(prefix + msg) + } + + /** + * Logs the value of the given parameters for the estimator being used in this session. + */ + def logParams(params: Param[_]*): Unit = { + val pairs: Seq[(String, JValue)] = for { + p <- params + value <- estimator.get(p) + } yield { + val cast = p.asInstanceOf[Param[Any]] + p.name -> parse(cast.jsonEncode(value)) + } + log(compact(render(map2jvalue(pairs.toMap)))) + } + + def logNumFeatures(num: Long): Unit = { + log(compact(render("numFeatures" -> num))) + } + + def logNumClasses(num: Long): Unit = { + log(compact(render("numClasses" -> num))) + } + + /** + * Logs the successful completion of the training session and the value of the learned model. 
+ */ + def logSuccess(model: Model[_]): Unit = { + log(s"training finished") + } +} + +/** + * Some common methods for logging information about a training session. + */ +private[ml] object Instrumentation { + private val counter = new AtomicLong(0) + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: Dataset[_]): Instrumentation[E] = { + create[E](estimator, dataset.rdd) + } + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: RDD[_]): Instrumentation[E] = { + new Instrumentation[E](estimator, dataset) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 03eb903bb8fe..f04c87259c94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -181,13 +181,12 @@ class GaussianMixture private ( val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - case None => { + case None => val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) - } } var llh = Double.MinValue // current log-likelihood diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 02417b112432..f87613cc72f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -183,7 +183,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val k = (metadata \ "k").extract[Int] val classNameV1_0 = SaveLoadV1_0.classNameV1_0 (loadedClassName, version) match { - case (classNameV1_0, "1.0") => { + case (classNameV1_0, "1.0") => val model = SaveLoadV1_0.load(sc, path) require(model.weights.length == k, s"GaussianMixtureModel requires weights of length $k " + @@ -192,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { s"GaussianMixtureModel requires gaussians of length $k" + s"got gaussians of length ${model.gaussians.length}") model - } case _ => throw new Exception( s"GaussianMixtureModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). 
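// A standalone sketch (not part of the patch) of the log format the new Instrumentation
// helper emits: logParams and logNumFeatures above render name/value pairs through json4s
// into one compact JSON object per log line.
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

val pairs: Seq[(String, JValue)] = Seq(
  "maxIter" -> JInt(10),
  "regParam" -> JDouble(0.01))

// Same rendering calls as Instrumentation.logParams / logNumFeatures.
println(compact(render(map2jvalue(pairs.toMap))))  // e.g. {"maxIter":10,"regParam":0.01}
println(compact(render("numFeatures" -> 692L)))    // {"numFeatures":692}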
Supported:\n" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 37a21cd879bf..8ff0b83e8b49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -253,16 +253,14 @@ class KMeans private ( } val centers = initialModel match { - case Some(kMeansCenters) => { + case Some(kMeansCenters) => Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) - } - case None => { + case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { initKMeansParallel(data) } - } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 12813fd412b1..d999b9be8e8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -130,7 +130,8 @@ class LDA private ( */ @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { - require(docConcentration.size > 0, "docConcentration must have > 0 elements") + require(docConcentration.size == 1 || docConcentration.size == k, + s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}") this.docConcentration = docConcentration this } @@ -260,15 +261,18 @@ class LDA private ( def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * Parameter for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that + * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be * important when LDA is run for many iterations. If the checkpoint directory is not set in - * [[org.apache.spark.SparkContext]], this setting is ignored. + * [[org.apache.spark.SparkContext]], this setting is ignored. 
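// A small sketch (not part of the patch) of the tightened docConcentration check above:
// the Dirichlet prior must have either one entry (symmetric) or exactly k entries.
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA().setK(3)
lda.setDocConcentration(Vectors.dense(1.1))             // OK: length 1, symmetric prior
lda.setDocConcentration(Vectors.dense(1.1, 1.2, 1.3))   // OK: length k = 3
// lda.setDocConcentration(Vectors.dense(1.1, 1.2))     // now rejected: size must be 1 or k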
(default = 10) * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { + require(checkpointInterval == -1 || checkpointInterval > 0, + s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}") this.checkpointInterval = checkpointInterval this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 4eb8fc049e61..24e1cff0dcc6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + require(centers.size == weights.size, + "Number of initial centers must be equal to number of weights") + require(centers.size == k, + s"Number of initial centers must be ${k} but got ${centers.size}") + require(weights.forall(_ >= 0), + s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } @@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + require(dim > 0, + s"Number of dimensions must be positive but got ${dim}") + require(weight >= 0, + s"Weight for each center must be nonnegative but got ${weight}") val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) val weights = Array.fill(k)(weight) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 4455681e5076..4344ab1bade9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.internal.Logging +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -566,4 +576,88 @@ object PrefixSpan extends Logging { @Since("1.5.0") class PrefixSpanModel[Item] @Since("1.5.0") ( @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) - extends Serializable + extends Saveable with Serializable { + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[PrefixSpanModel.load]]. 
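// A sketch (not part of the patch) of the new eager argument checks on StreamingKMeans:
// center/weight counts must match k and every weight must be nonnegative.
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val skm = new StreamingKMeans()
  .setK(2)
  .setDecayFactor(1.0)
  .setInitialCenters(
    Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0)),
    Array(1.0, 1.0))

// These would now fail fast with IllegalArgumentException instead of surfacing later:
// skm.setInitialCenters(Array(Vectors.dense(0.0, 0.0)), Array(1.0))                     // one center for k = 2
// skm.setInitialCenters(Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0)), Array(1.0, -1.0))  // negative weight
// skm.setRandomCenters(0, 1.0)                                                          // non-positive dimension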
+ * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + PrefixSpanModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + PrefixSpanModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel" + + def save(model: PrefixSpanModel[_], path: String): Unit = { + val sc = model.freqSequences.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqSequences.first().sequence(0)(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqSequences.map { x => + Row(x.sequence, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqSequences = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqSequences.select("sequence").head().get(0) + loadImpl(freqSequences, sample) + } + + def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = { + val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x => + val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray + val freq = x.getLong(1) + new PrefixSpan.FreqSequence(sequence, freq) + } + new PrefixSpanModel(freqSequencesRDD) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0f0c3a2df556..5812cdde2c42 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -186,7 +186,7 @@ sealed trait Vector extends Serializable { * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.DataFrame]]. + * via [[org.apache.spark.sql.Dataset]]. 
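// A usage sketch (not part of the patch) of the new PrefixSpanModel persistence added above;
// it assumes a spark-shell style `sc` and a hypothetical, writable /tmp output directory.
import org.apache.spark.mllib.fpm.{PrefixSpan, PrefixSpanModel}

val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6))), 2).cache()

val model = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .run(sequences)

val path = "/tmp/prefix-span-model"        // hypothetical path; must not already exist
model.save(sc, path)                       // JSON metadata + Parquet data, per SaveLoadV1_0 above
val sameModel = PrefixSpanModel.load(sc, path)
sameModel.freqSequences.collect().foreach { fs =>
  println(fs.sequence.map(_.mkString("[", ",", "]")).mkString("<", ",", ">") + ": " + fs.freq)
}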
*/ @AlphaComponent class VectorUDT extends UserDefinedType[Vector] { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index 0ec8975fed8f..9748fbf2c97b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -64,10 +64,11 @@ private[stat] object KolmogorovSmirnovTest extends Logging { */ def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = { val n = data.count().toDouble - val ksStat = data.sortBy(x => x).zipWithIndex().map { case (v, i) => - val f = cdf(v) - math.max(f - i / n, (i + 1) / n - f) - }.max() + val localData = data.sortBy(x => x).mapPartitions { part => + val partDiffs = oneSampleDifferences(part, n, cdf) // local distances + searchOneSampleCandidates(partDiffs) // candidates: local extrema + }.collect() + val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme evalOneSampleP(ksStat, n.toLong) } @@ -83,6 +84,74 @@ private[stat] object KolmogorovSmirnovTest extends Logging { testOneSample(data, cdf) } + /** + * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a + * partition + * @param partData `Iterator[Double]` 1 partition of a sorted RDD + * @param n `Double` the total size of the RDD + * @param cdf `Double => Double` a function the calculates the theoretical CDF of a value + * @return `Iterator[(Double, Double)] `Unadjusted (ie. off by a constant) potential extrema + * in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF, + * the second element corresponds to empirical CDF - CDF. We can then search the resulting + * iterator for the minimum of the first and the maximum of the second element, and provide + * this as a partition's candidate extrema + */ + private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double) + : Iterator[(Double, Double)] = { + // zip data with index (within that partition) + // calculate local (unadjusted) empirical CDF and subtract CDF + partData.zipWithIndex.map { case (v, ix) => + // dp and dl are later adjusted by constant, when global info is available + val dp = (ix + 1) / n + val dl = ix / n + val cdfVal = cdf(v) + (dl - cdfVal, dp - cdfVal) + } + } + + /** + * Search the unadjusted differences in a partition and return the + * two extrema (furthest below and furthest above CDF), along with a count of elements in that + * partition + * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF + * and CDFin a partition, which come as a tuple of + * (empirical CDF - 1/N - CDF, empirical CDF - CDF) + * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements + */ + private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)]) + : Iterator[(Double, Double, Double)] = { + val initAcc = (Double.MaxValue, Double.MinValue, 0.0) + val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) => + (math.min(pMin, dl), math.max(pMax, dp), pCt + 1) + } + val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults) + results.iterator + } + + /** + * Find the global maximum distance between empirical CDF and CDF (i.e. 
the KS statistic) after + * adjusting local extrema estimates from individual partitions with the amount of elements in + * preceding partitions + * @param localData `Array[(Double, Double, Double)]` A local array containing the collected + * results of `searchOneSampleCandidates` across all partitions + * @param n `Double`The size of the RDD + * @return The one-sample Kolmogorov Smirnov Statistic + */ + private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double) + : Double = { + val initAcc = (Double.MinValue, 0.0) + // adjust differences based on the number of elements preceding it, which should provide + // the correct distance between empirical CDF and CDF + val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) => + val adjConst = prevCt / n + val dist1 = math.abs(minCand + adjConst) + val dist2 = math.abs(maxCand + adjConst) + val maxVal = Array(prevMax, dist1, dist2).max + (maxVal, prevCt + ct) + } + results._1 + } + /** * A convenience function that allows running the KS test for 1 set of sample data against * a named distribution @@ -97,7 +166,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging { : KolmogorovSmirnovTestResult = { val distObj = distName match { - case "norm" => { + case "norm" => if (params.nonEmpty) { // parameters are passed, then can only be 2 require(params.length == 2, "Normal distribution requires mean and standard " + @@ -109,7 +178,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging { "initialized to standard normal (i.e. N(0, 1))") new NormalDistribution(0, 1) } - } case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" + s" convenience method. Current options are:['norm'].") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 1841fa4a95c9..26755849ad1a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -55,10 +55,15 @@ import org.apache.spark.util.Utils * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * Supported numerical values: "(0.0-1.0]", "[1-n]". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. + * If a real value "n" in the range (0, 1.0] is set, + * use n * number of features. + * If an integer value "n" in the range (1, num features) is set, + * use n features. * @param seed Random seed for bootstrapping and choosing feature subsets. */ private class RandomForest ( @@ -70,9 +75,11 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") - require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) + || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." 
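// A sketch (not part of the patch) of the public entry points whose internals are reworked
// above: the one-sample KS statistic is now assembled from per-partition candidate extrema
// instead of a global sortBy + zipWithIndex pass. Assumes a spark-shell style `sc`.
import org.apache.spark.mllib.random.RandomRDDs
import org.apache.spark.mllib.stat.Statistics

val data = RandomRDDs.normalRDD(sc, 1000L, 4, 11L)

// Against a named distribution (standard normal, mean 0.0 and stddev 1.0) ...
println(Statistics.kolmogorovSmirnovTest(data, "norm", 0.0, 1.0))
// ... or against an arbitrary CDF supplied as a function (a hypothetical logistic-style CDF).
val myCdf: Double => Double = x => 0.5 * (1.0 + math.tanh(x / 2.0))
println(Statistics.kolmogorovSmirnovTest(data, myCdf))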
+ - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") /** * Method to train a decision tree model over an RDD diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 75061464e546..5aec52ac72b1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -22,6 +22,7 @@ import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public void runDT() { for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestClassificationModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index b6f793f6de89..a8736669f72e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -22,6 +22,7 @@ import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public void runDT() { for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestRegressionModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 34daf5fbde80..8a67793abc14 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ 
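// A sketch (not part of the patch) of the numeric featureSubsetStrategy values that the Java
// suites above exercise, alongside the existing named strategies.
import org.apache.spark.ml.classification.RandomForestClassifier

val rf = new RandomForestClassifier().setNumTrees(20)
rf.setFeatureSubsetStrategy("sqrt")   // named strategies still work
rf.setFeatureSubsetStrategy("0.5")    // fraction in (0, 1.0]: use half of the features per split
rf.setFeatureSubsetStrategy("3")      // integer >= 1: use (up to) 3 features per split
// rf.setFeatureSubsetStrategy("1.1") // rejected: fraction outside (0, 1.0]
// rf.setFeatureSubsetStrategy("0")   // rejected: integer values must be at least 1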
b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { private transient JavaSparkContext sc; @@ -64,4 +66,39 @@ public void runPrefixSpan() { long freq = freqSeq.freq(); } } + + @Test + public void runPrefixSpanSaveLoad() { + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); + List> localFreqSeqs = freqSeqs.collect(); + Assert.assertEquals(5, localFreqSeqs.size()); + // Check that each frequent sequence could be materialized. + for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + List> seq = freqSeq.javaSequence(); + long freq = freqSeq.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + + + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index f3321fb5a1ab..a8c4ac6d0561 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] + when(dataset0.toDF).thenReturn(dataset0) + when(dataset1.toDF).thenReturn(dataset1) + when(dataset2.toDF).thenReturn(dataset2) + when(dataset3.toDF).thenReturn(dataset3) + when(dataset4.toDF).thenReturn(dataset4) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) when(model0.copy(any[ParamMap])).thenReturn(model0) when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) @@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl override def write: MLWriter = new DefaultParamsWriter(this) - override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } @@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer { override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) - 
override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 76d8c9372e9f..7e6aec6b1bb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. 
- try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 7eefaf234662..48db4281309b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 06ff049b480a..80547fad6af8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4727cd436f4c..80a46fc70c75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import 
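// A round-trip sketch (not part of the patch) of the GBT persistence that
// testEstimatorAndModelReadWrite exercises above; the path and training data here are
// hypothetical, and a spark-shell style `sqlContext` is assumed.
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.mllib.linalg.Vectors

val training = sqlContext.createDataFrame(Seq(
  (Vectors.dense(0.0, 1.0), 0.0),
  (Vectors.dense(0.2, 0.8), 0.0),
  (Vectors.dense(1.0, 0.1), 1.0),
  (Vectors.dense(0.9, 0.0), 1.0)
)).toDF("features", "label")

val model = new GBTClassifier()
  .setMaxIter(5)
  .setLossType("logistic")
  .fit(training)

model.write.overwrite().save("/tmp/gbt-classification-model")   // hypothetical path
val loaded = GBTClassificationModel.load("/tmp/gbt-classification-model")
assert(loaded.numFeatures == model.numFeatures)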
org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 41313967265b..f3e8fd11b296 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { @@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 18f2c994b474..e641d79c1707 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 8edd44e5f1ef..1a274aea291f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = 
_ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c684bc11cccf..2ca386e4229c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit & transform") { + test("fit, transform, and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) @@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: KMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a1c93891c78b..17d6e9fc2ee7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} object LDASuite { @@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val k: Int = 5 val vocabSize: Int = 30 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -289,4 +289,15 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getCheckpointFiles.isEmpty) } + + test("EM LDA disable checkpointing") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new 
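// A sketch (not part of the patch) of the model summary that the expanded
// "fit, transform, and summary" KMeans test above checks; assumes a spark-shell `sqlContext`.
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

val points = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(0.1, 0.1)),
  Tuple1(Vectors.dense(9.0, 9.0)),
  Tuple1(Vectors.dense(9.1, 9.1))
)).toDF("features")

val model = new KMeans().setK(2).setSeed(1L).fit(points)
val summary = model.summary                       // available because hasSummary is true after fit
println(summary.predictionCol)                    // "prediction" by default
println(summary.clusterSizes.mkString(", "))      // one size per cluster, summing to the row count
summary.predictions.show()                        // input columns plus the prediction column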
LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3) + .setCheckpointInterval(-1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index 04f165c5f1e7..7641e3b8cf66 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => @@ -168,21 +169,34 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("CountVectorizerModel with binary") { + test("CountVectorizerModel and CountVectorizer with binary") { val df = sqlContext.createDataFrame(Seq( - (0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))), + (0, split("a a a a b b b b c d"), + Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))), (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))), (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))) )).toDF("id", "words", "expected") - val cv = new CountVectorizerModel(Array("a", "b", "c", "d")) + // CountVectorizer test + val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .setBinary(true) + .fit(df) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => assert(features ~== expected absTol 1e-14) } + + // CountVectorizerModel test + val cv2 = new CountVectorizerModel(cv.vocabulary) + .setInputCol("words") + .setOutputCol("features") + .setBinary(true) + cv2.transform(df).select("features", "expected").collect().foreach { + case Row(features: Vector, expected: Vector) => + assert(features ~== expected absTol 1e-14) + } } test("CountVectorizer read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 58fda29aa1e6..e4e15f43310b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -22,7 +22,7 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: 
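// A sketch (not part of the patch) of the binary-mode behaviour the rewritten test above checks
// on both the CountVectorizer estimator and a vocabulary-built CountVectorizerModel; a
// spark-shell style `sqlContext` is assumed.
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

val docs = sqlContext.createDataFrame(Seq(
  (0, Seq("a", "a", "b", "c")),
  (1, Seq("c", "c")),
  (2, Seq("a"))
)).toDF("id", "words")

val fitted = new CountVectorizer()
  .setInputCol("words")
  .setOutputCol("features")
  .setBinary(true)
  .fit(docs)
fitted.transform(docs).select("features").show(false)     // every nonzero count becomes 1.0

val fromVocab = new CountVectorizerModel(fitted.vocabulary)
  .setInputCol("words")
  .setOutputCol("features")
  .setBinary(true)
fromVocab.transform(docs).select("features").show(false)  // same output as the fitted model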
Array[String]) @@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: DataFrame): Unit = { + def testNGram(t: NGram, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("nGrams", "wantedNGrams") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 25fabf64d559..8895d630a087 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,78 +17,60 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ - - test("Test quantile discretizer") { - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 10, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 4, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 3, - Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2), - Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")) + test("Test observed number of buckets and their sizes match expected values") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 2, - Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1), - Array("-Infinity, 2.0", "2.0, Infinity")) + val datasetSize = 100000 + val numBuckets = 5 + val df = sc.parallelize(1.0 to datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) - } + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") - test("Test getting splits") { - val splitTestPoints = Array( - Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity, Double.PositiveInfinity) - -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity), - Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity) - ) - for ((ori, res) <- splitTestPoints) { - assert(QuantileDiscretizer.getSplits(ori) === res, 
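// A sketch (not part of the patch) of what the reworked bucket-count check here asserts:
// fitting on a uniform column yields numBuckets buckets of roughly equal size. Assumes a
// spark-shell style `sqlContext`.
import org.apache.spark.ml.feature.QuantileDiscretizer

val train = sqlContext.createDataFrame((1 to 1000).map(i => Tuple1(i.toDouble))).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(5)

val bucketizer = discretizer.fit(train)          // returns a Bucketizer with the learned splits
bucketizer.transform(train)
  .groupBy("result").count()
  .orderBy("result")
  .show()                                        // five buckets of roughly 200 rows each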
"Returned splits are invalid.") + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) } + val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - test("Test splits on dataset larger than minSamplesRequired") { + test("Test transform method on unseen data") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 - val numBuckets = 5 - val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") - .setNumBuckets(numBuckets) - .setSeed(1) + .setNumBuckets(5) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count + val result = discretizer.fit(trainDF).transform(testDF) + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") } test("read/write") { @@ -98,34 +80,17 @@ class QuantileDiscretizerSuite .setNumBuckets(6) testDefaultReadWrite(t) } -} - -private object QuantileDiscretizerSuite extends SparkFunSuite { - def checkDiscretizedData( - sc: SparkContext, - data: Array[Double], - numBucket: Int, - expectedResult: Array[Double], - expectedAttrs: Array[String]): Unit = { + test("Verify resulting model has parent") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") - val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket).setSeed(1) + val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(5) val model = discretizer.fit(df) assert(model.hasParent) - val result = model.transform(df) - - val transformedFeatures = result.select("result").collect() - .map { case Row(transformedFeature: Double) => transformedFeature } - val transformedAttrs = Attribute.fromStructField(result.schema("result")) - .asInstanceOf[NominalAttribute].values.get - - assert(transformedFeatures === expectedResult, - "Transformed features do not equal expected features.") - assert(transformedAttrs === expectedAttrs, - "Transformed attributes do not equal expected attributes.") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 553e0b870216..e213e17d0d9d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -21,6 +21,7 @@ import 
org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -49,4 +50,13 @@ class SQLTransformerSuite .setStatement("select * from __THIS__") testDefaultReadWrite(t) } + + test("transformSchema") { + val df = sqlContext.range(10) + val outputSchema = new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transformSchema(df.schema) + val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) + assert(outputSchema === expected) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index a5b24c18565b..3505befdf8e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("filtered", "expected") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2c3255ef3336..d0f3cdc841d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -115,7 +115,7 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = sqlContext.range(0L, 10L).toDF() - assert(indexerModel.transform(df).eq(df)) + assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) } test("StringIndexerModel can't overwrite output column") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index 36e8e5d86838..299f6223b20b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) @@ -106,7 +106,7 @@ class RegexTokenizerSuite object RegexTokenizerSuite extends SparkFunSuite { - def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() diff --git 
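// A sketch (not part of the patch) of the behaviour the new transformSchema test above pins
// down: SQLTransformer reports the statement's output schema from the input schema alone.
// Assumes a spark-shell style `sqlContext`.
import org.apache.spark.ml.feature.SQLTransformer

val df = sqlContext.range(10)
val st = new SQLTransformer().setStatement("SELECT id + 1 AS id1 FROM __THIS__")

println(st.transformSchema(df.schema))   // StructType(StructField(id1,LongType,false))
st.transform(df).show()                  // and transform produces the same single id1 column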
a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index f4844cc67118..76891ad56281 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ + @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite datasetMultivariate = sqlContext.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + datasetUnivariateScaled = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }) } /** @@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite } } + test("numerical stability of standardization") { + val trainer = new AFTSurvivalRegression() + val model1 = trainer.fit(datasetUnivariate) + val model2 = trainer.fit(datasetUnivariateScaled) + + /** + * During training we standardize the dataset first, so no matter how we multiple + * a scaling factor into the dataset, the convergence rate should be the same, + * and the coefficients should equal to the original coefficients multiple by + * the scaling factor. It will have no effect on the intercept and scale. + */ + assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01) + assert(model1.intercept ~== model2.intercept absTol 0.01) + assert(model1.scale ~== model2.scale absTol 0.01) + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 3c11631f9882..216377959e09 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -32,7 +32,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTRegressor]]. 
*/ -class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs @@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) - val newModel = GBTRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = GBTRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTRegressionModel, + model2: GBTRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTRegressorSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 2265464b51dd..3ecc210abdfc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: binomial family with weight") { @@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: poisson family with weight") { @@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: gamma family with weight") { @@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("read/write") { @@ -992,6 +996,14 @@ class GeneralizedLinearRegressionSuite assert(expected.coefficients === actual.coefficients) } } + + 
test("glm accepts Dataset[LabeledPoint]") { + val context = sqlContext + import context.implicits._ + new GeneralizedLinearRegression() + .setFamily("gaussian") + .fit(datasetGaussianIdentity.as[LabeledPoint]) + } } object GeneralizedLinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e64551f03c92..6db9ce150d93 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -327,7 +327,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) + case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") } + case _ => throw new AssertionError("model.rootNode was not an InternalNode") } } @@ -352,6 +354,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => throw new AssertionError("rootNode was not an InternalNode") } // Single group second level tree construction. @@ -423,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val integerStrategies = Array("1", "10", "100", "1000", "10000") + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) + } + } + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) + } + } } test("Binary classification with continuous features: subsampling features") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7af3c6d6ede4..3e734aabc554 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4030956fabea..dbee47c8475d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite @@ -158,7 +158,7 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -172,7 +172,7 @@ object TrainValidationSplitSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 16280473c6ac..7ebd7eb14463 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, 
Dataset} trait DefaultReadWriteTest extends TempDirectory { self: Suite => @@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, - dataset: DataFrame, + dataset: Dataset[_], testParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a83e543859b8..6d8c7b47d837 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("model save/load") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val newModel = PrefixSpanModel.load(sc, path) + val originalSet = model.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + val newSet = newModel.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(originalSet === newSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { diff --git a/pom.xml b/pom.xml index f37a8988f726..a772d513372e 100644 --- a/pom.xml +++ b/pom.xml @@ -94,6 +94,7 @@ core graphx mllib + mllib-local tools streaming sql/catalyst @@ -139,7 +140,7 @@ 1.6.0 8.1.14.v20131031 3.0.0.v201112011016 - 0.7.4 + 0.8.0 2.4.0 2.0.8 3.1.2 @@ -277,31 +278,11 @@ com.twitter chill_${scala.binary.version} ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - + + junit + junit + + + org.apache.commons + commons-math3 + + + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.2.10 + com.sun.jersey jersey-json @@ -685,7 +688,7 @@ com.spotify docker-client shaded - 3.4.0 + 3.6.6 test @@ -2273,7 +2276,7 @@ false true false - ${basedir}/src/main/java + ${basedir}/src/main/java,${basedir}/src/main/scala ${basedir}/src/test/java dev/checkstyle.xml ${basedir}/target/checkstyle-output.xml diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a53161dc9a38..71f337ce1f63 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -329,6 +329,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), + // [SPARK-14451][SQL] Move 
encoder definition into Aggregator interface + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") @@ -614,9 +619,18 @@ object MimaExcludes { ) ++ Seq( // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") + ) ++ Seq( + // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") ) ++ Seq( // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") + ) ++ Seq( + // [SPARK-14475] Propagate user-defined context from driver to executors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + // [SPARK-14617] Remove deprecated APIs in TaskMetrics + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$") ) case v if v.startsWith("1.6") => Seq( @@ -836,7 +850,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), @@ -845,10 +858,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), @@ -859,7 +870,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), @@ -873,7 +883,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 60124ef0a13b..a58dd7e7f125 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -47,9 +47,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* ) = Seq( - "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "test-tags", "sketch" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects @@ -254,7 +254,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, testTags, sketch + unsafe, testTags, sketch, mllibLocal ).contains(x) } @@ -366,8 +366,10 @@ object Flume { object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "18.0" + dependencyOverrides += "com.google.guava" % "guava" % "18.0", + resolvers ++= Seq("DB2" at "https://app.camunda.com/nexus/content/repositories/public/") ) + } /** diff --git a/python/docs/Makefile b/python/docs/Makefile index 903009790ba3..905e0215c20c 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -2,10 +2,10 @@ # # You can set these variables from the command line. 
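Among the MimaExcludes added above, the SPARK-14475 entry whitelists the new TaskContext.getLocalProperty method, which lets code running in a task read properties set on the driver. A hedged sketch of the intended usage (the property key is made up for illustration):

    // Driver side: attach a local property to jobs submitted from this thread.
    sc.setLocalProperty("job.tag", "nightly-run")

    // Executor side: read it back inside a task; returns null if the key was never set.
    val tags = sc.parallelize(1 to 4).map { i =>
      org.apache.spark.TaskContext.get().getLocalProperty("job.tag")
    }.collect()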
-SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = _build +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +PAPER ?= +BUILDDIR ?= _build export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 529d16b48039..cb15b4b91f91 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -428,15 +428,19 @@ def f(split, iterator): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - # Make sure we distribute data evenly if it's smaller than self.batchSize - if "__len__" not in dir(c): - c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) - serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - serializer.dump_stream(c, tempFile) - tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + try: + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) + serializer.dump_stream(c, tempFile) + tempFile.close() + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + finally: + # readRDDFromFile eagerily reads the file so we can delete right after. + os.unlink(tempFile.name) return RDD(jrdd, self, serializer) def pickleFile(self, name, minPartitions=None): diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d98919b3c639..922f8069fac4 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,7 +19,7 @@ from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( @@ -272,7 +272,7 @@ def evaluate(self, dataset): return BinaryLogisticRegressionSummary(java_blr_summary) -class LogisticRegressionSummary(JavaCallable): +class LogisticRegressionSummary(JavaWrapper): """ Abstraction for Logistic Regression Results for a given model. @@ -291,7 +291,7 @@ def predictions(self): @since("2.0.0") def probabilityCol(self): """ - Field in "predictions" which gives the calibrated probability + Field in "predictions" which gives the probability of each class as a vector. 
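The docstring fix above clarifies that probabilityCol names the column carrying the full per-class probability vector rather than a single calibrated value. A small Scala sketch of reading it off a training summary (trainDF with "label" and "features" columns is assumed):

    import org.apache.spark.ml.classification.{LogisticRegression, BinaryLogisticRegressionSummary}

    val lrModel = new LogisticRegression().setMaxIter(10).fit(trainDF)
    val summary = lrModel.summary                  // training summary attached by fit()
    println(summary.probabilityCol)                // column holding the probability vector per row
    val binary = summary.asInstanceOf[BinaryLogisticRegressionSummary]
    println(binary.areaUnderROC)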
""" return self._call_java("probabilityCol") diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index c9b95b3bf45d..4b0bade10280 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,7 +18,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only @@ -81,7 +81,7 @@ def isLargerBetter(self): @inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): +class JavaEvaluator(JavaParams, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86b53285b5b0..809a513316f9 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) Set the params for the CountVectorizer """ kwargs = self.setParams._input_kwargs @@ -324,6 +333,21 @@ def getVocabSize(self): """ return self.getOrDefault(self.vocabSize) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + def _create_model(self, java_model): return CountVectorizerModel(java_model) @@ -512,14 +536,19 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java .. versionadded:: 1.3.0 """ + binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. 
" + + "This is useful for discrete probabilistic models that model binary events " + + "rather than integer counts. Default False.", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None): """ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) - self._setDefault(numFeatures=1 << 18) + self._setDefault(numFeatures=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -533,6 +562,21 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2b5504bc2966..9d654e8b0f8d 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.common import inherit_doc @@ -177,7 +177,7 @@ def _from_java(cls, java_stage): # Create a new instance of this stage. py_stage = cls() # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] py_stage.setStages(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage @@ -195,7 +195,7 @@ def _to_java(self): for idx, stage in enumerate(self.getStages()): java_stages[idx] = stage._to_java() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) _java_obj.setStages(java_stages) return _java_obj @@ -275,7 +275,7 @@ def _from_java(cls, java_stage): Used for ML persistence. """ # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] # Create a new instance of this stage. 
py_stage = cls(py_stages) py_stage._resetUid(java_stage.uid()) @@ -295,6 +295,6 @@ def _to_java(self): java_stages[idx] = stage._to_java() _java_obj =\ - JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f6c5d130dd85..c064fe500c3c 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,7 +20,7 @@ from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -28,6 +28,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', + 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', @@ -187,7 +188,7 @@ def evaluate(self, dataset): return LinearRegressionSummary(java_lr_summary) -class LinearRegressionSummary(JavaCallable): +class LinearRegressionSummary(JavaWrapper): """ .. note:: Experimental @@ -331,6 +332,9 @@ def coefficientStandardErrors(self): Standard error of estimated coefficients and intercept. This value is only available when using the "normal" solver. + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + .. seealso:: :py:attr:`LinearRegression.solver` """ return self._call_java("coefficientStandardErrors") @@ -342,6 +346,9 @@ def tValues(self): T-statistic of estimated coefficients and intercept. This value is only available when using the "normal" solver. + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + .. seealso:: :py:attr:`LinearRegression.solver` """ return self._call_java("tValues") @@ -353,6 +360,9 @@ def pValues(self): Two-sided p-value of estimated coefficients and intercept. This value is only available when using the "normal" solver. + If :py:attr:`LinearRegression.fitIntercept` is set to True, + then the last element returned corresponds to the intercept. + .. seealso:: :py:attr:`LinearRegression.solver` """ return self._call_java("pValues") @@ -1188,6 +1198,150 @@ def predict(self, features): return self._call_java("predict", features) +@inherit_doc +class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol, + HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol, + HasSolver, JavaMLWritable, JavaMLReadable): + """ + Generalized Linear Regression. + + Fit a Generalized Linear Model specified by giving a symbolic description of the linear + predictor (link function) and a description of the error distribution (family). It supports + "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family + is listed below. The first link function of each family is the default one. + - "gaussian" -> "identity", "log", "inverse" + - "binomial" -> "logit", "probit", "cloglog" + - "poisson" -> "log", "identity", "sqrt" + - "gamma" -> "inverse", "identity", "log" + + .. 
seealso:: `GLM `_ + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(0.0, 0.0)), + ... (1.0, Vectors.dense(1.0, 2.0)), + ... (2.0, Vectors.dense(0.0, 0.0)), + ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) + >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity") + >>> model = glr.fit(df) + >>> abs(model.transform(df).head().prediction - 1.5) < 0.001 + True + >>> model.coefficients + DenseVector([1.5..., -1.0...]) + >>> abs(model.intercept - 1.5) < 0.001 + True + >>> glr_path = temp_path + "/glr" + >>> glr.save(glr_path) + >>> glr2 = GeneralizedLinearRegression.load(glr_path) + >>> glr.getFamily() == glr2.getFamily() + True + >>> model_path = temp_path + "/glr_model" + >>> model.save(model_path) + >>> model2 = GeneralizedLinearRegressionModel.load(model_path) + >>> model.intercept == model2.intercept + True + >>> model.coefficients[0] == model2.coefficients[0] + True + + .. versionadded:: 2.0.0 + """ + + family = Param(Params._dummy(), "family", "The name of family which is a description of " + + "the error distribution to be used in the model. Supported options: " + + "gaussian(default), binomial, poisson and gamma.") + link = Param(Params._dummy(), "link", "The name of link function which provides the " + + "relationship between the linear predictor and the mean of the distribution " + + "function. Supported options: identity, log, inverse, logit, probit, cloglog " + + "and sqrt.") + + @keyword_only + def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + """ + super(GeneralizedLinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + Sets params for generalized linear regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GeneralizedLinearRegressionModel(java_model) + + @since("2.0.0") + def setFamily(self, value): + """ + Sets the value of :py:attr:`family`. + """ + self._paramMap[self.family] = value + return self + + @since("2.0.0") + def getFamily(self): + """ + Gets the value of family or its default value. + """ + return self.getOrDefault(self.family) + + @since("2.0.0") + def setLink(self, value): + """ + Sets the value of :py:attr:`link`. + """ + self._paramMap[self.link] = value + return self + + @since("2.0.0") + def getLink(self): + """ + Gets the value of link or its default value. 
+ """ + return self.getOrDefault(self.link) + + +class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + Model fitted by GeneralizedLinearRegression. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("2.0.0") + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + if __name__ == "__main__": import doctest import pyspark.ml.regression diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2dcd5eeb52c2..86c0254a2b7b 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -52,7 +52,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -406,6 +406,22 @@ def test_stopwordsremover(self): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) + def test_count_vectorizer_with_binary(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + class HasInducedError(Params): @@ -644,7 +660,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaWrapper): + if isinstance(m1, JavaParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) @@ -831,6 +847,25 @@ def test_logistic_regression_summary(self): self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) +class HashingTFTest(PySparkTestCase): + + def test_apply_binary_term_freqs(self): + sqlContext = SQLContext(self.sc) + + df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 100 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.sparse(n, {(ord("a") % n): 1.0, + (ord("b") % n): 1.0, + (ord("c") % n): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index da00f317b348..456d79d897e0 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import 
keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java @@ -148,8 +148,8 @@ def _from_java_impl(cls, java_stage): """ # Load information from java_stage to the instance. - estimator = JavaWrapper._from_java(java_stage.getEstimator()) - evaluator = JavaWrapper._from_java(java_stage.getEvaluator()) + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) epms = [estimator._transfer_param_map_from_java(epm) for epm in java_stage.getEstimatorParamMaps()] return estimator, epms, evaluator @@ -329,7 +329,7 @@ def _to_java(self): estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -393,7 +393,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. py_stage = cls(bestModel=bestModel)\ @@ -410,10 +410,10 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", - self.uid, - self.bestModel._to_java(), - _py2java(sc, [])) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -574,8 +574,8 @@ def _to_java(self): estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", - self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -588,6 +588,8 @@ def _to_java(self): class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): """ Model from train validation split. + + .. versionadded:: 2.0.0 """ def __init__(self, bestModel): @@ -637,7 +639,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = \ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
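Complementing the GeneralizedLinearRegression Python wrapper added above, and tying back to the summary.solver assertions earlier in this patch, a hedged Scala sketch that also touches the training summary (df with "label" and "features" columns is assumed):

    import org.apache.spark.ml.regression.GeneralizedLinearRegression

    val glr = new GeneralizedLinearRegression()
      .setFamily("poisson")        // gaussian, binomial, poisson or gamma
      .setLink("log")              // must be a valid link for the chosen family
      .setMaxIter(25)
      .setRegParam(0.0)

    val model = glr.fit(df)
    val summary = model.summary
    println(s"solver=${summary.solver}, AIC=${summary.aic}")  // solver is reported as "irls"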
@@ -655,7 +657,7 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj( + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d4411fdfb9dd..9dfcef0e40d6 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -99,7 +99,7 @@ def context(self, sqlContext): @inherit_doc class JavaMLWriter(MLWriter): """ - (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types """ def __init__(self, instance): @@ -178,7 +178,7 @@ def context(self, sqlContext): @inherit_doc class JavaMLReader(MLReader): """ - (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types """ def __init__(self, clazz): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index a2cf2296fbe0..cd0e5b80d555 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -25,29 +25,32 @@ from pyspark.mllib.common import inherit_doc, _java2py, _py2java -@inherit_doc -class JavaWrapper(Params): +class JavaWrapper(object): """ - Utility class to help create wrapper classes from Java/Scala - implementations of pipeline components. + Wrapper class for a Java companion object """ + def __init__(self, java_obj=None): + super(JavaWrapper, self).__init__() + self._java_obj = java_obj - __metaclass__ = ABCMeta - - def __init__(self): + @classmethod + def _create_from_java_class(cls, java_class, *args): """ - Initialize the wrapped java object to None + Construct this object from given Java classname and arguments """ - super(JavaWrapper, self).__init__() - #: The wrapped Java companion object. Subclasses should initialize - #: it properly. The param values in the Java object should be - #: synced with the Python wrapper in fit/transform/evaluate/copy. - self._java_obj = None + java_obj = JavaWrapper._new_java_obj(java_class, *args) + return cls(java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) @staticmethod def _new_java_obj(java_class, *args): """ - Construct a new Java object. + Returns a new Java object. """ sc = SparkContext._active_spark_context java_obj = _jvm() @@ -56,6 +59,18 @@ def _new_java_obj(java_class, *args): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + +@inherit_doc +class JavaParams(JavaWrapper, Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + #: The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + + __metaclass__ = ABCMeta + def _make_java_param_pair(self, param, value): """ Makes a Java parm pair. @@ -151,7 +166,7 @@ def __get_class(clazz): stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. py_type = __get_class(stage_name) - if issubclass(py_type, JavaWrapper): + if issubclass(py_type, JavaParams): # Load information from java_stage to the instance. 
py_stage = py_type() py_stage._java_obj = java_stage @@ -166,7 +181,7 @@ def __get_class(clazz): @inherit_doc -class JavaEstimator(Estimator, JavaWrapper): +class JavaEstimator(JavaParams, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. @@ -199,7 +214,7 @@ def _fit(self, dataset): @inherit_doc -class JavaTransformer(Transformer, JavaWrapper): +class JavaTransformer(JavaParams, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. Subclasses should ensure they have the transformer Java object @@ -213,30 +228,8 @@ def _transform(self, dataset): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) -class JavaCallable(object): - """ - Wrapper for a plain object in JVM to make Java calls, can be used - as a mixin to another class that defines a _java_obj wrapper - """ - def __init__(self, java_obj=None, sc=None): - super(JavaCallable, self).__init__() - self._sc = sc if sc is not None else SparkContext._active_spark_context - # if this class is a mixin and _java_obj is already defined then don't initialize - if java_obj is not None or not hasattr(self, "_java_obj"): - self._java_obj = java_obj - - def __del__(self): - if self._java_obj is not None: - self._sc._gateway.detach(self._java_obj) - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - java_args = [_py2java(self._sc, arg) for arg in args] - return _java2py(self._sc, m(*java_args)) - - @inherit_doc -class JavaModel(Model, JavaCallable, JavaTransformer): +class JavaModel(JavaTransformer, Model): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -249,7 +242,7 @@ def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, - and then call _transformer_params_from_java. + and then call _transfer_params_from_java. This instance can be instantiated without specifying java_model, it will be assigned after that, but this scenario only used by @@ -259,9 +252,8 @@ def __init__(self, java_model=None): these wrappers depend on pyspark.ml.util (both directly and via other ML classes). 
""" - super(JavaModel, self).__init__() + super(JavaModel, self).__init__(java_model) if java_model is not None: - self._java_obj = java_model self.uid = java_model.uid() def copy(self, extra=None): diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 612935352575..b3dd2f63a5d8 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -379,6 +379,17 @@ class HashingTF(object): """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + self.binary = False + + @since("2.0.0") + def setBinary(self, value): + """ + If True, term frequency vector will be binary such that non-zero + term counts will be set to 1 + (default: False) + """ + self.binary = value + return self @since('1.2.0') def indexOf(self, term): @@ -398,7 +409,7 @@ def transform(self, document): freq = {} for term in document: i = self.indexOf(term) - freq[i] = freq.get(i, 0) + 1.0 + freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0 return Vectors.sparse(self.numFeatures, freq.items()) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5f515b666ce1..ac55fbf79841 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -58,6 +58,7 @@ from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import HashingTF from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct @@ -1583,6 +1584,21 @@ def test_als_ratings_id_long_error(self): self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4008332c84d0..11dfcfe13ee0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -405,7 +405,7 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - Py4JJavaError:... + Py4JJavaError: ... """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d473d6b53464..b4fa8368936a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -60,7 +60,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + people.filter(people.age > 30).join(department, people.deptId == department.id)\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. 
note:: Experimental diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 176e3bb41c99..ef012d27cb22 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -55,7 +55,7 @@ def __str__(self): StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2) StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) -StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) +StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) """ .. note:: The following four storage level constants are deprecated in 2.0, since the records \ diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 15c87e22f98b..97ea39dde05f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1914,6 +1914,13 @@ def test_get_or_create(self): with SparkContext.getOrCreate() as sc: self.assertTrue(SparkContext.getOrCreate() is sc) + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index c8b78bc14a39..547da8f713ac 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -285,7 +285,7 @@ class ReplSuite extends SparkFunSuite { val output = runInterpreter("local", """ |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.{Encoder, Encoders} |import org.apache.spark.sql.expressions.Aggregator |import org.apache.spark.sql.TypedColumn |val simpleSum = new Aggregator[Int, Int, Int] { @@ -293,6 +293,8 @@ class ReplSuite extends SparkFunSuite { | def reduce(b: Int, a: Int) = b + a // Add an element to the running total | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. | def finish(b: Int) = b // Return the final result. + | def bufferEncoder: Encoder[Int] = Encoders.scalaInt + | def outputEncoder: Encoder[Int] = Encoders.scalaInt |}.toColumn | |val ds = Seq(1, 2, 3, 4).toDS() @@ -339,30 +341,6 @@ class ReplSuite extends SparkFunSuite { } } - test("Datasets agg type-inference") { - val output = runInterpreter("local", - """ - |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.Encoder - |import org.apache.spark.sql.expressions.Aggregator - |import org.apache.spark.sql.TypedColumn - |/** An `Aggregator` that adds up any numeric type returned by the given function. 
*/ - |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { - | val numeric = implicitly[Numeric[N]] - | override def zero: N = numeric.zero - | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) - | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) - | override def finish(reduction: N): N = reduction - |} - | - |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn - |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() - |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect() - """.stripMargin) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - } - test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index dbfacba34637..d3dafe9c42ee 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite { val output = runInterpreter("local", """ |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.Encoder + |import org.apache.spark.sql.{Encoder, Encoders} |import org.apache.spark.sql.expressions.Aggregator |import org.apache.spark.sql.TypedColumn |val simpleSum = new Aggregator[Int, Int, Int] { @@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite { | def reduce(b: Int, a: Int) = b + a // Add an element to the running total | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values. | def finish(b: Int) = b // Return the final result. + | def bufferEncoder: Encoder[Int] = Encoders.scalaInt + | def outputEncoder: Encoder[Int] = Encoders.scalaInt |}.toColumn | |val ds = Seq(1, 2, 3, 4).toDS() @@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite { } } - test("Datasets agg type-inference") { - val output = runInterpreter("local", - """ - |import org.apache.spark.sql.functions._ - |import org.apache.spark.sql.Encoder - |import org.apache.spark.sql.expressions.Aggregator - |import org.apache.spark.sql.TypedColumn - |/** An `Aggregator` that adds up any numeric type returned by the given function. 
*/ - |class SumOf[I, N : Numeric](f: I => N) extends - | org.apache.spark.sql.expressions.Aggregator[I, N, N] { - | val numeric = implicitly[Numeric[N]] - | override def zero: N = numeric.zero - | override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) - | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2) - | override def finish(reduction: N): N = reduction - |} - | - |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn - |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() - |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect() - """.stripMargin) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - } - test("collecting objects of class defined in repl") { val output = runInterpreter("local[2]", """ @@ -396,4 +373,31 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("should clone and clean line object in ClosureCleaner") { + val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", + """ + |import org.apache.spark.rdd.RDD + | + |val lines = sc.textFile("pom.xml") + |case class Data(s: String) + |val dataRDD = lines.map(line => Data(line.take(3))) + |dataRDD.cache.count + |val repartitioned = dataRDD.repartition(dataRDD.partitions.size) + |repartitioned.cache.count + | + |def getCacheSize(rdd: RDD[_]) = { + | sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum + |} + |val cacheSize1 = getCacheSize(dataRDD) + |val cacheSize2 = getCacheSize(repartitioned) + | + |// The cache size of dataRDD and the repartitioned one should be similar. + |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1 + |assert(deviation < 0.2, + | s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2") + """.stripMargin) + assertDoesNotContain("AssertionError", output) + assertDoesNotContain("Exception", output) + } } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 928aaa56293b..4a15d52b570a 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -70,26 +70,24 @@ class ExecutorClassLoader( } override def findClass(name: String): Class[_] = { - userClassPathFirst match { - case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) - case false => { - try { - parentLoader.loadClass(name) - } catch { - case e: ClassNotFoundException => { - val classOption = findClassLocally(name) - classOption match { - case None => - // If this class has a cause, it will break the internal assumption of Janino - // (the compiler used for Spark SQL code-gen). - // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see - // its behavior will be changed if there is a cause and the compilation - // of generated class will fail. - throw new ClassNotFoundException(name) - case Some(a) => a - } + if (userClassPathFirst) { + findClassLocally(name).getOrElse(parentLoader.loadClass(name)) + } else { + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => + val classOption = findClassLocally(name) + classOption match { + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). 
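The ReplSuite updates above reflect SPARK-14451: an Aggregator now declares its own bufferEncoder and outputEncoder instead of relying on implicit encoder inference, and the removed "Datasets agg type-inference" test depended on the old behavior. Pulled out of the REPL string for readability, the updated aggregator looks roughly like this:

    import org.apache.spark.sql.{Encoder, Encoders}
    import org.apache.spark.sql.expressions.Aggregator

    val simpleSum = new Aggregator[Int, Int, Int] {
      def zero: Int = 0                                 // starting value
      def reduce(b: Int, a: Int): Int = b + a           // add an element to the running total
      def merge(b1: Int, b2: Int): Int = b1 + b2        // merge intermediate values
      def finish(b: Int): Int = b                       // final result
      def bufferEncoder: Encoder[Int] = Encoders.scalaInt
      def outputEncoder: Encoder[Int] = Encoders.scalaInt
    }.toColumn

    // With the session implicits in scope: Seq(1, 2, 3, 4).toDS().select(simpleSum).collect()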
+ // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) + case Some(a) => a } - } } } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 472a8f40844a..a14e3e583f87 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -228,6 +228,11 @@ This file is divided into 3 sections: Use Javadoc style indentation for multiline comments + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. + + diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 85cb585919da..f5e291ecac3d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -55,6 +55,8 @@ statement rowFormat? createFileFormat? locationSpec? (TBLPROPERTIES tablePropertyList)? (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq?)? #analyze | ALTER (TABLE | VIEW) from=tableIdentifier @@ -104,6 +106,7 @@ statement REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? @@ -118,6 +121,7 @@ statement | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW CREATE TABLE tableIdentifier #showCreateTable | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? @@ -135,20 +139,16 @@ statement ; hiveNativeCommands - : createTableHeader LIKE tableIdentifier - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - | DELETE FROM tableIdentifier (WHERE booleanExpression)? + : DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? - | DROP VIEW (IF EXISTS)? qualifiedName | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? | START TRANSACTION (transactionMode (',' transactionMode)*)? | COMMIT WORK? | ROLLBACK WORK? | SHOW PARTITIONS tableIdentifier partitionSpec? | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | MSCK | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*? ; unsupportedHiveNativeCommands @@ -177,6 +177,7 @@ unsupportedHiveNativeCommands | kw1=UNLOCK kw2=DATABASE | kw1=CREATE kw2=TEMPORARY kw3=MACRO | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=MSCK kw2=REPAIR kw3=TABLE ; createTableHeader @@ -271,8 +272,7 @@ createFileFormat ; fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? - (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? 
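The scalastyle rule added above (pattern `case[^\n>]*=>\s*\{`) flags braces placed right after the arrow of a case clause; the same cleanup is applied by hand to Cast, EquivalentExpressions and CodeGenerator later in this patch. A small illustration of what the rule rejects and what it expects instead:

    // Rejected by the new rule: a brace immediately after the case arrow.
    def describe(n: Int): String = n match {
      case 0 => {
        "zero"
      }
      case _ => "non-zero"
    }

    // Accepted: the case body stands on its own lines without braces.
    def describeClean(n: Int): String = n match {
      case 0 =>
        "zero"
      case _ => "non-zero"
    }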
#tableFileFormat | identifier #genericFileFormat ; @@ -651,7 +651,7 @@ nonReserved | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION ; @@ -867,6 +867,7 @@ GRANT: 'GRANT'; LOCK: 'LOCK'; UNLOCK: 'UNLOCK'; MSCK: 'MSCK'; +REPAIR: 'REPAIR'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3555a6d7faef..de40ddde1bdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -409,7 +409,7 @@ class Analyzer( catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"Table not found: ${u.tableName}") + u.failAnalysis(s"Table or View not found: ${u.tableName}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 488050239806..d6a8c3eec81a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -52,7 +52,7 @@ trait CheckAnalysis { case p if p.analyzed => // Skip already analyzed sub-plans case u: UnresolvedRelation => - u.failAnalysis(s"Table not found: ${u.tableIdentifier}") + u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}") case operator: LogicalPlan => operator transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f239b33e44ee..f2abf136da68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -171,6 +171,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CaseWhen]("when"), // math functions expression[Acos]("acos"), @@ -217,6 +218,12 @@ object FunctionRegistry { expression[Tan]("tan"), expression[Tanh]("tanh"), + expression[Add]("+"), + expression[Subtract]("-"), + expression[Multiply]("*"), + expression[Divide]("/"), + expression[Remainder]("%"), + // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), @@ -257,6 +264,7 @@ object FunctionRegistry { expression[Lower]("lcase"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[Like]("like"), expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), @@ -267,6 +275,7 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), 
expression[StringReverse]("reverse"), + expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), @@ -343,7 +352,29 @@ object FunctionRegistry { expression[NTile]("ntile"), expression[Rank]("rank"), expression[DenseRank]("dense_rank"), - expression[PercentRank]("percent_rank") + expression[PercentRank]("percent_rank"), + + // predicates + expression[And]("and"), + expression[In]("in"), + expression[Not]("not"), + expression[Or]("or"), + + expression[EqualNullSafe]("<=>"), + expression[EqualTo]("="), + expression[EqualTo]("=="), + expression[GreaterThan](">"), + expression[GreaterThanOrEqual](">="), + expression[LessThan]("<"), + expression[LessThanOrEqual]("<="), + expression[Not]("!"), + + // bitwise + expression[BitwiseAnd]("&"), + expression[BitwiseNot]("~"), + expression[BitwiseOr]("|"), + expression[BitwiseXor]("^") + ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index e9f04eecf8d7..5e18316c94bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec @@ -24,29 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
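With `like`, `rlike`, the arithmetic operators and the bitwise operators now registered by name in `FunctionRegistry`, expressions such as the ones below resolve through the registry. A hedged sketch in SQL, run through an assumed session `spark` (all literals are illustrative):

    spark.sql("SELECT 3 % 2, 10 / 4, 7 - 2 * 3").show()                  // 1, 2.5, 1
    spark.sql("SELECT 'spark' LIKE 'sp%', 'spark' RLIKE '^sp'").show()   // true, true
    spark.sql("SELECT 3 & 5, 3 | 5, 3 ^ 5, ~0").show()                   // 1, 7, 6, -1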
*/ -abstract class NoSuchItemException extends Exception { - override def getMessage: String -} +class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found") -class NoSuchDatabaseException(db: String) extends NoSuchItemException { - override def getMessage: String = s"Database $db not found" -} +class NoSuchTableException(db: String, table: String) + extends AnalysisException(s"Table or View $table not found in database $db") -class NoSuchTableException(db: String, table: String) extends NoSuchItemException { - override def getMessage: String = s"Table $table not found in database $db" -} +class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends + AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n")) -class NoSuchPartitionException( - db: String, - table: String, - spec: TablePartitionSpec) - extends NoSuchItemException { - - override def getMessage: String = { - s"Partition not found in table $table database $db:\n" + spec.mkString("\n") - } -} - -class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException { - override def getMessage: String = s"Function $func not found in database $db" -} +class NoSuchFunctionException(db: String, func: String) + extends AnalysisException(s"Function $func not found in database $db") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 186bbccef120..569614476f61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -62,7 +62,7 @@ class InMemoryCatalog extends ExternalCatalog { private def requireTableExists(db: String, table: String): Unit = { if (!tableExists(db, table)) { throw new AnalysisException( - s"Table not found: '$table' does not exist in database '$db'") + s"Table or View not found: '$table' does not exist in database '$db'") } } @@ -164,7 +164,7 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables.remove(table) } else { if (!ignoreIfNotExists) { - throw new AnalysisException(s"Table '$table' does not exist in database '$db'") + throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'") } } } @@ -187,6 +187,10 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables(table).table } + override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized { + if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table) + } + override def tableExists(db: String, table: String): Boolean = synchronized { requireDbExists(db) catalog(db).tables.contains(table) @@ -201,6 +205,11 @@ class InMemoryCatalog extends ExternalCatalog { StringUtils.filterPattern(listTables(db), pattern) } + override def showCreateTable(db: String, table: String): String = { + throw new AnalysisException( + "SHOW CREATE TABLE command is not supported for temporary tables created in SQLContext.") + } + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 
7db9fd0527ec..980dda986338 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -21,6 +21,7 @@ import java.io.File import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -41,7 +42,7 @@ class SessionCatalog( externalCatalog: ExternalCatalog, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: CatalystConf) extends Logging { import ExternalCatalog._ def this( @@ -175,6 +176,26 @@ class SessionCatalog( externalCatalog.getTable(db, table) } + /** + * Retrieve the metadata of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then return None if it doesn't exist. + */ + def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.getTableOption(db, table) + } + + /** + * Generate Create table DDL string for the specified tableIdentifier + */ + def showCreateTable(name: TableIdentifier): String = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.showCreateTable(db, table) + } + // ------------------------------------------------------------- // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- @@ -229,7 +250,13 @@ class SessionCatalog( val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { - externalCatalog.dropTable(db, table, ignoreIfNotExists) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. + if (externalCatalog.tableExists(db, table)) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true) + } else if (!ignoreIfNotExists) { + logError(s"Table or View '${name.quotedString}' does not exist") + } } else { tempTables.remove(table) } @@ -283,7 +310,7 @@ class SessionCatalog( * explicitly specified. 
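The new `getTableMetadataOption` returns `None` for a missing table instead of raising, so callers can pattern match rather than catch exceptions; `dropTable` likewise now logs an error instead of failing when the relation is absent. A sketch of the calling pattern against the internal catalog API, assuming a `SessionCatalog` instance named `sessionCatalog` and an illustrative table name:

    import org.apache.spark.sql.catalyst.TableIdentifier

    // 'sessionCatalog' is assumed to be an org.apache.spark.sql.catalyst.catalog.SessionCatalog.
    sessionCatalog.getTableMetadataOption(TableIdentifier("events", Some("default"))) match {
      case Some(table) =>
        println(s"partitioned by: ${table.partitionColumnNames.mkString(", ")}")
      case None =>
        println("no such table; no exception was thrown")
    }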
*/ def isTemporaryTable(name: TableIdentifier): Boolean = { - !name.database.isDefined && tempTables.contains(formatTableName(name.table)) + name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e29d6bd8b09e..f20699a58cf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -91,12 +91,16 @@ abstract class ExternalCatalog { def getTable(db: String, table: String): CatalogTable + def getTableOption(db: String, table: String): Option[CatalogTable] + def tableExists(db: String, table: String): Boolean def listTables(db: String): Seq[String] def listTables(db: String, pattern: String): Seq[String] + def showCreateTable(db: String, table: String): String + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- @@ -218,14 +222,30 @@ case class CatalogTable( tableType: CatalogTableType, storage: CatalogStorageFormat, schema: Seq[CatalogColumn], - partitionColumns: Seq[CatalogColumn] = Seq.empty, - sortColumns: Seq[CatalogColumn] = Seq.empty, - numBuckets: Int = 0, + partitionColumnNames: Seq[String] = Seq.empty, + sortColumnNames: Seq[String] = Seq.empty, + bucketColumnNames: Seq[String] = Seq.empty, + numBuckets: Int = -1, createTime: Long = System.currentTimeMillis, - lastAccessTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, properties: Map[String, String] = Map.empty, viewOriginalText: Option[String] = None, - viewText: Option[String] = None) { + viewText: Option[String] = None, + comment: Option[String] = None) { + + // Verify that the provided columns are part of the schema + private val colNames = schema.map(_.name).toSet + private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = { + require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " + + s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'") + } + requireSubsetOfSchema(partitionColumnNames, "partition") + requireSubsetOfSchema(sortColumnNames, "sort") + requireSubsetOfSchema(bucketColumnNames, "bucket") + + /** Columns this table is partitioned by. */ + def partitionColumns: Seq[CatalogColumn] = + schema.filter { c => partitionColumnNames.contains(c.name) } /** Return the database this table was specified to belong to, assuming it exists. 
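The `requireSubsetOfSchema` check added to `CatalogTable` fails construction early when partition, sort or bucket columns are not part of the declared schema. The check itself is a plain subset test; a standalone restatement with illustrative column names:

    val schemaCols = Set("id", "name", "ds")

    def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit =
      require(cols.toSet.subsetOf(schemaCols),
        s"$colType columns (${cols.mkString(", ")}) must be a subset of schema (${schemaCols.mkString(", ")})")

    requireSubsetOfSchema(Seq("ds"), "partition")    // passes
    // requireSubsetOfSchema(Seq("dt"), "partition") // would throw IllegalArgumentException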
*/ def database: String = identifier.database.getOrElse { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d842ffdc6637..0f8876a9e688 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -898,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") - val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") @@ -920,7 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } """ - } }.mkString("\n") (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index affd1bdb327c..8d8cc152ff29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -97,11 +97,11 @@ class EquivalentExpressions { def debugString(all: Boolean = false): String = { val sb: mutable.StringBuilder = new StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.foreach { case (k, v) => { + equivalenceMap.foreach { case (k, v) => if (all || v.length > 1) { sb.append(" " + v.mkString(", ")).append("\n") } - }} + } sb.toString() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index dbd0acf06caa..2ed6fc0d3824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDDState +import org.apache.spark.rdd.InputFileNameHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** - * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + * Expression that returns the name of the current file being read. 
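The `InputFileName` expression touched by this rename is exposed in SQL as `input_file_name()` and through `org.apache.spark.sql.functions.input_file_name`. A brief sketch, assuming a session `spark` and an illustrative local path:

    import org.apache.spark.sql.functions.input_file_name

    // Each row is tagged with the underlying file it was read from.
    val df = spark.read.text("/tmp/sample.txt")
    df.select(input_file_name().alias("source_file"), df("value")).show(truncate = false)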
*/ @ExpressionDescription( usage = "_FUNC_() - Returns the name of the current file being read if available", @@ -40,12 +40,12 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDDState.getInputFileName() + InputFileNameHolder.getInputFileName() } override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4615c55d676f..61ca7272dfa6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.types._ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any - def update(v: Any) + def update(v: Any): Unit def copy(): MutableValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94ac4bf09b90..ff7077484783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the mean calculated from values of a group.") case class Average(child: Expression) extends DeclarativeAggregate { override def prettyName: String = "avg" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 9d2db4514481..17a7c6dce89c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -130,6 +130,10 @@ abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate } // Compute the population standard deviation of a column +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.") +// scalastyle:on line.size.limit case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -143,6 +147,8 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample standard deviation of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.") case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -157,6 +163,8 @@ case class StddevSamp(child: 
Expression) extends CentralMomentAgg(child) { } // Compute the population variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.") case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -170,6 +178,8 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.") case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -183,6 +193,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "var_samp" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.") case class Skewness(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "skewness" @@ -196,6 +208,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.") case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index e6b8214ef25e..e29265e2f41e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.") case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = Seq(x, y) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 663c69e799fb..17ae012af79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(*) - Returns the total number of retrieved rows, including rows containing NULL values. + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL. 
+ _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""") +// scalastyle:on line.size.limit case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index c175a8c4c77b..d80afbebf740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -76,6 +76,8 @@ abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggre } } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), @@ -85,6 +87,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 35f57426feaf..b8ab0364dd8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). */ +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows. + _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows. + If isIgnoreNull is true, returns only non-null values. 
+ """) case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index b6bd56cff6b3..1d218da6db80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.lang.{Long => JLong} import java.util -import com.clearspring.analytics.hash.MurmurHash - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -48,6 +46,11 @@ import org.apache.spark.sql.types._ * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++. + _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++ + with relativeSD, the maximum estimation error allowed. + """) case class HyperLogLogPlusPlus( child: Expression, relativeSD: Double = 0.05, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index be7e12d7a233..b05d74b49b59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). 
*/ +@ExpressionDescription( + usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.") case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 906003188d4f..c534fe495fc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the maximum value of expr.") case class Max(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 39f7afbd081c..35289b468183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the minimum value of expr.") case class Min(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 08a67ea3df51..ad217f25b5a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sum calculated from values of a group.") case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b38809153882..f3d42fc0b216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - +@ExpressionDescription( + usage = "_FUNC_(a) - Returns -a.") case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -59,6 +60,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression override def sql: String = 
s"(-${child.sql})" } +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a.") case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" @@ -79,8 +82,8 @@ case class UnaryPositive(child: Expression) * A function that get the absolute value of the numeric value. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", - extended = "> SELECT _FUNC_('-1');\n1") + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.", + extended = "> SELECT _FUNC_('-1');\n 1") case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -126,6 +129,8 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a+b.") case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -155,6 +160,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a-b.") case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -185,6 +192,8 @@ case class Subtract(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Multiplies a by b.") case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -198,6 +207,9 @@ case class Multiply(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } +@ExpressionDescription( + usage = "a _FUNC_ b - Divides a by b.", + extended = "> SELECT 3 _FUNC_ 2;\n 1.5") case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -275,6 +287,8 @@ case class Divide(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns the remainder when dividing a by b.") case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -464,6 +478,9 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns the positive modulo", + extended = "> SELECT _FUNC_(10,3);\n 1") case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 4c90b3f7d33a..a7e1cd66f24a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.types._ * * Code generation inherited from BinaryArithmetic. 
*/ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise AND.", + extended = "> SELECT 3 _FUNC_ 5; 1") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise OR.", + extended = "> SELECT 3 _FUNC_ 5; 7") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise exclusive OR.", + extended = "> SELECT 3 _FUNC_ 5; 2") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ +@ExpressionDescription( + usage = "_FUNC_ b - Bitwise NOT.", + extended = "> SELECT _FUNC_ 0; -1") case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ee7f4fadca89..f43626ca814a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -519,7 +519,7 @@ class CodegenContext { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) - commonExprs.foreach(e => { + commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") val isNull = s"${fnName}IsNull" @@ -561,7 +561,7 @@ class CodegenContext { subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) - }) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e36c9852491b..ab790cf372d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.types._ /** * Given an array or map, returns its size. 
*/ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the size of an array or a map.") case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'") +// scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -125,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** * Checks if the array (left) has the element (right) */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.", + extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c299586ddefe..74de4a776de8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(n0, ...) - Returns an array with the given elements.") case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -73,6 +75,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a catalyst Map containing the evaluation of all children expressions as keys and values. * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...) */ +@ExpressionDescription( + usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.") case class CreateMap(children: Seq[Expression]) extends Expression { private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) @@ -153,6 +157,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { /** * Returns a Row containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -204,6 +210,10 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) 
*/ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 35a7b4602074..ae6a94842f7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -23,7 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.") +// scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) extends Expression { @@ -85,6 +88,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) extends Expression with CodegenFallback { @@ -256,6 +263,8 @@ object CaseKeyWhen { * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.") case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) @@ -315,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression { * A function that returns the greatest value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.") case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1d0ea68d7a7b..9135753041f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -35,6 +35,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * * There is no code generation since this expression should get constant folded by the optimizer. 
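The collection constructors and conditional helpers documented above compose naturally in a single statement; a small sketch matching the documented behaviour (session `spark` assumed, literals illustrative):

    spark.sql("SELECT size(array(1, 2, 3)), array_contains(array(1, 2, 3), 2)").show()          // 3, true
    spark.sql("SELECT sort_array(array('b', 'd', 'c', 'a'))").show()                            // [a, b, c, d]
    spark.sql("SELECT named_struct('a', 1, 'b', 'x'), map('k1', 1, 'k2', 2)").show()
    spark.sql("SELECT if(1 = 1, 'yes', 'no'), least(3, 1, NULL), greatest(3, 1, NULL)").show()  // yes, 1, 3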
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current date at the start of query evaluation.") case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -54,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { * * There is no code generation since this expression should get constant folded by the optimizer. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.") case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -70,6 +74,9 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { /** * Adds a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'") case class DateAdd(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -96,6 +103,9 @@ case class DateAdd(startDate: Expression, days: Expression) /** * Subtracts a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'") case class DateSub(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = startDate @@ -118,6 +128,9 @@ case class DateSub(startDate: Expression, days: Expression) override def prettyName: String = "date_sub" } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12") case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -134,6 +147,9 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58") case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -150,6 +166,9 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59") case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -166,6 +185,9 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of year of date/timestamp.", + extended = "> SELECT _FUNC_('2016-04-09');\n 100") case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -182,7 +204,9 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } } - +@ExpressionDescription( + usage = "_FUNC_(param) - Returns 
the year component of the date/timestamp/interval.", + extended = "> SELECT _FUNC_('2016-07-30');\n 2016") case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -199,6 +223,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.") case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -215,6 +241,9 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval", + extended = "> SELECT _FUNC_('2016-07-30');\n 7") case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -231,6 +260,9 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.", + extended = "> SELECT _FUNC_('2009-07-30');\n 30") case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -247,6 +279,9 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the week of the year of the given date.", + extended = "> SELECT _FUNC_('2008-02-20');\n 8") case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -283,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.", + extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'") +// scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -310,6 +350,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * Converts time string with given pattern. * Deterministic version of [[UnixTimestamp]], must have at least one parameter. */ +@ExpressionDescription( + usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.") case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -331,6 +373,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. */ +@ExpressionDescription( + usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.") case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -459,6 +503,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { * format. 
If the format is missing, using format like "1970-01-01 00:00:00". * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. */ +@ExpressionDescription( + usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format", + extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'") case class FromUnixTime(sec: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -544,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression) /** * Returns the last day of the month which the date belongs to. */ +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.", + extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'") case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def child: Expression = startDate @@ -570,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC * * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.", + extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'") +// scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -654,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression) /** * Assumes given timestamp is UTC and converts to given timezone. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.") +// scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -729,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression) /** * Returns the date that is num_months after start_date. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.", + extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'") case class AddMonths(startDate: Expression, numMonths: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -756,6 +818,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression) /** * Returns number of months between dates date1 and date2. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.", + extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677") case class MonthsBetween(date1: Expression, date2: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -783,6 +848,10 @@ case class MonthsBetween(date1: Expression, date2: Expression) /** * Assumes given timestamp is in given timezone and converts to UTC. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.") +// scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -830,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) /** * Returns the date part of a timestamp or string. 
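A few of the extended examples given above for the date and time functions, run end to end (session `spark` assumed; the `from_unixtime` output is shown for a UTC session timezone):

    spark.sql("SELECT date_add('2016-07-30', 1), date_sub('2016-07-30', 1)").show()   // 2016-07-31, 2016-07-29
    spark.sql("SELECT year('2016-07-30'), month('2016-07-30'), dayofmonth('2009-07-30'), weekofyear('2008-02-20')").show()  // 2016, 7, 30, 8
    spark.sql("SELECT next_day('2015-01-14', 'TU'), last_day('2009-01-12')").show()   // 2015-01-20, 2009-01-31
    spark.sql("SELECT from_unixtime(0, 'yyyy-MM-dd HH:mm:ss')").show()                // 1970-01-01 00:00:00
    spark.sql("SELECT months_between('1997-02-28 10:30:00', '1996-10-30'), datediff('2009-07-31', '2009-07-30'), trunc('2015-10-27', 'YEAR')").show()  // ~3.9496, 1, 2015-01-01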
*/ +@ExpressionDescription( + usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.", + extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'") case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // Implicit casting of spark will accept string in both date and timestamp format, as @@ -850,6 +922,11 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Returns date truncated to the unit specified by the format. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date, fmt) - Returns date with the time portion of the day truncated to the unit specified by the format model fmt.", + extended = "> SELECT _FUNC_('2009-02-12', 'MM');\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'") +// scalastyle:on line.size.limit case class TruncDate(date: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = date @@ -921,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression) /** * Returns the number of days from startDate to endDate. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.", + extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1") case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e7ef21aa8589..65d7a1d5a090 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -99,6 +99,10 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.") +// scalastyle:on line.size.limit case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 72b323587c63..ecd09b7083f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -106,6 +106,8 @@ private[this] object SharedFactory { * Extracts json object from a json string based on json path specified, and returns json string * of the extracted json object. It will return null if the input json string is invalid.
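The `extended` strings added above double as runnable examples. A minimal sketch (assuming a Spark shell with a `sqlContext` in scope, matching the era of this patch) that exercises a few of the newly documented date/time functions:

```scala
// Exercise some of the date/time expressions documented above; the expected values
// mirror the `extended` examples embedded in the annotations.
val dates = sqlContext.sql(
  """SELECT year('2016-07-30'),
    |       weekofyear('2008-02-20'),
    |       from_unixtime(0, 'yyyy-MM-dd HH:mm:ss'),
    |       next_day('2015-01-14', 'TU'),
    |       months_between('1997-02-28 10:30:00', '1996-10-30'),
    |       datediff('2009-07-31', '2009-07-30')""".stripMargin)
dates.show(false)
// Expected single row: 2016, 8, 1970-01-01 00:00:00, 2015-01-20, 3.94959677, 1
```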
*/ +@ExpressionDescription( + usage = "_FUNC_(json_txt, path) - Extract a json object from path") case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -319,6 +321,10 @@ case class GetJsonObject(json: Expression, path: Expression) } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.") +// scalastyle:on line.size.limit case class JsonTuple(children: Seq[Expression]) extends Generator with CodegenFallback { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e3d1bc127d2e..c8a28e847745 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -50,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String) /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * * @param f The math function. * @param name The short name of the function */ @@ -103,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) /** * A binary expression specifically for math functions that take two `Double`s as input and returns * a `Double`. + * * @param f The math function. * @param name The short name of the function */ @@ -136,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) * Euler's number. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns Euler's number, E.", + extended = "> SELECT _FUNC_();\n 2.718281828459045") case class EulerNumber() extends LeafMathExpression(math.E, "E") /** * Pi. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. 
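For the two JSON helpers annotated above, a short sketch (hypothetical JSON document, again assuming a `sqlContext`) shows the intended usage; `json_tuple` is a generator and returns one string column per requested field:

```scala
// get_json_object extracts a single path; json_tuple extracts several fields at once.
val doc = """{"name": "spark", "version": "2.0"}"""
sqlContext.sql(s"SELECT get_json_object('$doc', '$$.name')").show()
// 'spark'
sqlContext.sql(s"SELECT json_tuple('$doc', 'name', 'version')").show()
// 'spark', '2.0'  -- all output columns are strings, as the usage text notes
```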
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns PI.", + extended = "> SELECT _FUNC_();\n 3.141592653589793") case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -150,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc tangent.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cube root of a double value.", + extended = "> SELECT _FUNC_(27.0);\n 3.0") case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.", + extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -184,16 +207,26 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") /** * Convert a num from one base to another + * * @param numExpr the number to be converted * @param fromBaseExpr from which base * @param toBaseExpr to which base */ +@ExpressionDescription( + usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.", + extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'") case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -222,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns e to the power of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns exp(x) - 1.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the largest integer not greater than x.", + extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5") case class 
Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -283,6 +325,9 @@ object Factorial { ) } +@ExpressionDescription( + usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.", + extended = "> SELECT _FUNC_(5);\n 120") case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -315,8 +360,14 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.", + extended = "> SELECT _FUNC_(1);\n 0.0") case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 2.", + extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -332,36 +383,72 @@ case class Log2(child: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 10.", + extended = "> SELECT _FUNC_(10);\n 1.0") case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns log(1 + x).", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { protected override val yAsymptote: Double = -1.0 } +@ExpressionDescription( + usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sign of x.", + extended = "> SELECT _FUNC_(40);\n 1.0") case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the square root of x.", + extended = "> SELECT _FUNC_(4);\n 2.0") case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +@ExpressionDescription( + usage = "_FUNC_(x) - Converts radians to degrees.", + extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0") case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { override def funcName: String = "toDegrees" } +@ExpressionDescription( + usage = "_FUNC_(x) - Converts degrees to radians.", + extended = "> 
SELECT _FUNC_(180);\n 3.141592653589793") case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns x in binary.", + extended = "> SELECT _FUNC_(13);\n '1101'") case class Bin(child: Expression) extends UnaryExpression with Serializable with ImplicitCastInputTypes { @@ -453,6 +540,9 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Convert the argument to hexadecimal.", + extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'") case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = @@ -481,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Converts hexadecimal argument to binary.", + extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'") case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -509,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// - +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the arc tangent2.", + extended = "> SELECT _FUNC_(0, 0);\n 0.0") case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -523,6 +618,9 @@ case class Atan2(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.", + extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -532,10 +630,14 @@ case class Pow(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise left shift. + * * @param left the base number to shift. * @param right number of bits to left shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise left shift.", + extended = "> SELECT _FUNC_(2, 1);\n 4") case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -558,10 +660,14 @@ case class ShiftLeft(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise right shift. + * * @param left the base number to shift. - * @param right number of bits to left shift. + * @param right number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -585,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression) /** * Bitwise unsigned right shift, for integer and long data type. + * * @param left the base number. 
* @param right the number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise unsigned right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -608,16 +718,22 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).", + extended = "> SELECT _FUNC_(3, 4);\n 5.0") case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") /** * Computes the logarithm of a number. + * * @param left the logarithm base, default to e. * @param right the number to compute the logarithm of. */ +@ExpressionDescription( + usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.", + extended = "> SELECT _FUNC_(10, 100);\n 2.0") case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -674,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression) * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Round(child: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index eb8dc1423afb..4bd918ed01ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -438,6 +438,8 @@ abstract class InterpretedHashFunction { * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle * and bucketing have same data distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.") case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { def this(arguments: Seq[Expression]) = this(arguments, 42) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index e22026d58465..6a452499430c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -34,6 +34,9 @@ import org.apache.spark.sql.types._ * coalesce(null, null, null) => null * }}} */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.", + extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1") case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -89,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** * Evaluates to `true` iff it's NaN. 
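A quick sanity check (assuming a `sqlContext`) over a handful of the math expressions documented above, using the values given in their `extended` examples:

```scala
// conv, hex, shiftleft, hypot, two-argument log and round, as documented above.
sqlContext.sql(
  "SELECT conv('100', 2, 10), hex(17), shiftleft(2, 1), hypot(3, 4), log(10, 100), round(12.3456, 1)"
).show()
// Expected: 4, 11, 4, 5.0, 2.0, 12.3
```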
*/ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.") case class IsNaN(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes { @@ -126,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. * This Expression is useful for mapping NaN values to null. */ +@ExpressionDescription( + usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.") case class NaNvl(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -180,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression) /** * An expression that is evaluated to true if the input is null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.") case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -201,6 +210,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { /** * An expression that is evaluated to true if the input is not null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.") case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 28b6b2adf80a..26b1ff39b3e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -446,6 +446,8 @@ case class MapObjects private( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) + ctx.addMutableState("boolean", loopVar.isNull, "") + ctx.addMutableState(elementJavaType, loopVar.value, "") val genInputData = inputData.gen(ctx) val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") @@ -466,9 +468,9 @@ case class MapObjects private( } val loopNullCheck = if (primitiveElement) { - s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -484,7 +486,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType ${loopVar.value} = + ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4eb33258ac04..38f1210a4edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -88,7 +88,8 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - +@ExpressionDescription( + usage = "_FUNC_ a - Logical not") case class Not(child: Expression) extends UnaryExpression with Predicate with 
ImplicitCastInputTypes with NullIntolerant { @@ -109,6 +110,8 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ +@ExpressionDescription( + usage = "expr _FUNC_(val1, val2, ...) - Returns true if expr equals any valN.") case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { @@ -243,6 +246,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } +@ExpressionDescription( + usage = "a _FUNC_ b - Logical AND.") case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -306,7 +311,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Logical OR.") case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -401,7 +407,8 @@ private[sql] object Equality { } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a equals b and FALSE otherwise.") case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -426,7 +433,9 @@ case class EqualTo(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = """a _FUNC_ b - Returns the same result as the EQUAL(=) operator for non-null operands, + but returns TRUE if both are NULL, FALSE if one of them is NULL.""") case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { override def inputType: AbstractDataType = AnyDataType @@ -467,7 +476,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is less than b.") case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -480,7 +490,8 @@ case class LessThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.") case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -493,7 +504,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is greater than b.") case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -506,7 +518,8 @@ case class GreaterThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.") case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6be3cbcae629..1ec092a5be96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -55,6 +55,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).") case class Rand(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -78,6 +80,8 @@ case class Rand(seed: Long) extends RDG { } /** Generate a random column with i.i.d. gaussian random distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.") case class Randn(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b68009331b0a..85a54292639d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -67,6 +67,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { /** * Simple RegEx pattern matching function */ +@ExpressionDescription( + usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.") case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -117,7 +119,8 @@ case class Like(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.") case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -169,6 +172,9 @@ case class RLike(left: Expression, right: Expression) /** * Splits str around pat (pattern is a regular expression). */ +@ExpressionDescription( + usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex", + extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']") case class StringSplit(str: Expression, pattern: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -198,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression) * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ +@ExpressionDescription( + usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'") case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -289,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
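The predicate and null-handling descriptions above are easiest to read side by side. A small sketch (assuming a `sqlContext`) contrasting ordinary equality with the null-safe `<=>` operator, plus `coalesce`, `isnull` and `nanvl`:

```scala
// '=' propagates NULL, '<=>' does not; coalesce/isnull/nanvl behave as documented above.
sqlContext.sql(
  "SELECT NULL = NULL, NULL <=> NULL, coalesce(NULL, 1, NULL), isnull(NULL), nanvl(cast('NaN' AS double), 0.0)"
).show()
// Expected: NULL, true, 1, true, 0.0
```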
*/ +@ExpressionDescription( + usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches regexp.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index be6b2530ef39..93a827852869 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow { abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(i: Int, value: Any) + def update(i: Int, value: Any): Unit // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7e0e7a833b0d..a17482697d90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.", + extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'") case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ +@ExpressionDescription( + usage = + "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.", + extended = "> SELECT _FUNC_(' ', 'Spark', 'SQL');\n 'Spark SQL'") case class ConcatWs(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -188,7 +195,7 @@ case class Upper(child: Expression) */ @ExpressionDescription( usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", - extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") + extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -270,6 +277,11 @@ object StringTranslate { * The translate will happen when any character in the string matching with the character * in the `matchingExpr`.
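The regular-expression helpers annotated above can also be driven through the DataFrame API, which avoids worrying about backslash escaping inside SQL string literals. A sketch (assuming a `sqlContext`):

```scala
import org.apache.spark.sql.functions._

// split, regexp_replace and regexp_extract, matching the extended examples above.
val re = sqlContext.range(1).select(
  split(lit("one two three"), " ").as("parts"),
  regexp_replace(lit("100-200"), "(\\d+)", "num").as("replaced"),
  regexp_extract(lit("100-200"), "(\\d+)-(\\d+)", 1).as("extracted"))
re.show(false)
// Expected: [one, two, three], num-num, 100
```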
*/ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""", + extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'") +// scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac * delimited list (right). Returns 0, if the string wasn't found or if the given * string (left) contains a comma. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right). + Returns 0, if the string wasn't found or if the given string (left) contains a comma.""", + extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3") +// scalastyle:on case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -347,6 +365,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi /** * A function that trim the spaces from both ends for the specified string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'") case class StringTrim(child: Expression) extends UnaryExpression with String2StringExpression { @@ -362,6 +383,9 @@ case class StringTrim(child: Expression) /** * A function that trim the spaces from left end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '") case class StringTrimLeft(child: Expression) extends UnaryExpression with String2StringExpression { @@ -377,6 +401,9 @@ case class StringTrimLeft(child: Expression) /** * A function that trim the spaces from right end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'") case class StringTrimRight(child: Expression) extends UnaryExpression with String2StringExpression { @@ -396,6 +423,9 @@ case class StringTrimRight(child: Expression) * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +@ExpressionDescription( + usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.", + extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6") case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -422,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression) * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim. + If count is positive, everything to the left of the final delimiter (counting from the + left) is returned. 
If count is negative, everything to the right of the final delimiter + (counting from the right) is returned. Substring_index performs a case-sensitive match + when searching for delim.""", + extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'") +// scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -445,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * A function that returns the position of the first occurrence of substr * in given string after position pos. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos. + The given pos and return value are 1-based.""", + extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7") +// scalastyle:on line.size.limit case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -510,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) /** * Returns str, left-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringLPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -531,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) /** * Returns str, right-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringRPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -552,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according to printf-style format strings */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.", + extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'") +// scalastyle:on line.size.limit case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "format_string() should take at least 1 argument") @@ -642,6 +702,9 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI /** * Returns the string which repeats the given string value n times. */ +@ExpressionDescription( + usage = "_FUNC_(str, n) - Returns the string which repeats the given string value n times.", + extended = "> SELECT _FUNC_('123', 2);\n '123123'") case class StringRepeat(str: Expression, times: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -664,6 +727,9 @@ case class StringRepeat(str: Expression, times: Expression) /** * Returns the reversed given string.
*/ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the reversed given string.", + extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'") case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.reverse() @@ -677,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ +@ExpressionDescription( + usage = "_FUNC_(n) - Returns a n spaces string.", + extended = "> SELECT _FUNC_(2);\n ' '") case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -699,7 +768,14 @@ case class StringSpace(child: Expression) /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. + * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.", + extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'") +// scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -737,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string or binary expression. */ +@ExpressionDescription( + usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", + extended = "> SELECT _FUNC_('Spark SQL');\n 9") case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -757,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy /** * A function that return the Levenshtein distance between the two given strings. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.", + extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -775,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * A function that return soundex code of the given string expression. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns soundex code of the string.", + extended = "> SELECT _FUNC_('Miller');\n 'M460'") case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -791,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT /** * Returns the numeric value of the first character of str. 
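A sampler (assuming a `sqlContext`) of the string expressions documented in this file, with expected results taken from their `extended` examples:

```scala
// concat_ws, translate, substring_index, lpad, repeat, levenshtein and soundex.
sqlContext.sql(
  """SELECT concat_ws(' ', 'Spark', 'SQL'),
    |       translate('AaBbCc', 'abc', '123'),
    |       substring_index('www.apache.org', '.', 2),
    |       lpad('hi', 5, '??'),
    |       repeat('123', 2),
    |       levenshtein('kitten', 'sitting'),
    |       soundex('Miller')""".stripMargin
).show(false)
// Expected: Spark SQL, A1B2C3, www.apache, ???hi, 123123, 3, M460
```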
*/ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the numeric value of the first character of str.", + extended = "> SELECT _FUNC_('222');\n 50\n" + + "> SELECT _FUNC_(2);\n 50") case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType @@ -822,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. */ +@ExpressionDescription( + usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.") case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -844,6 +935,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Converts the argument from a base 64 string to BINARY. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.") case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType @@ -865,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. */ +@ExpressionDescription( + usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.") case class Decode(bin: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -894,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression) * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. -*/ + */ +@ExpressionDescription( + usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.") case class Encode(value: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -924,6 +1021,11 @@ case class Encode(value: Expression, charset: Expression) * and returns the result as a string. If D is 0, the result has no decimal point or * fractional part. */ +@ExpressionDescription( + usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places. + If D is 0, the result has no decimal point or fractional part. 
+ This is supposed to function like MySQL's FORMAT.""", + extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'") case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 619514e8aacb..f5172b213a74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ReorderJoin, OuterJoinElimination, PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, + PushDownPredicate, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, @@ -86,6 +84,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyConditionals, RemoveDispensableExpressions, + BinaryComparisonSimplification, PruneFilters, EliminateSorts, SimplifyCasts, @@ -287,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { val newFirstChild = Project(projectList, children.head) - val newOtherChildren = children.tail.map ( child => { + val newOtherChildren = children.tail.map { child => val rewrites = buildRewrites(children.head, child) Project(projectList.map(pushToRight(_, rewrites)), child) - } ) + } Union(newFirstChild +: newOtherChildren) } else { p @@ -518,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] { // Cases like "something\%" are not optimized, but this does not affect correctness. private val startsWith = "([^_%]+)%".r private val endsWith = "%([^_%]+)".r + private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r private val contains = "%([^_%]+)%".r private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(utf, StringType)) => - utf.toString match { - case startsWith(pattern) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case endsWith(pattern) => - EndsWith(l, Literal(pattern)) - case contains(pattern) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case equalTo(pattern) => - EqualTo(l, Literal(pattern)) + case Like(input, Literal(pattern, StringType)) => + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. 
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) case _ => - Like(l, Literal.create(utf, StringType)) + Like(input, Literal.create(pattern, StringType)) } } } @@ -786,6 +791,29 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Simplifies binary comparisons with semantically-equal expressions: + * 1) Replace '<=>' with 'true' literal. + * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. + * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. + */ +object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + // True with equality + case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => + TrueLiteral + case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + + // False with inequality + case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + } + } +} + /** * Simplifies conditional expressions (if / case). */ @@ -893,12 +921,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { } /** - * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] - * that were defined in the projection. + * Pushes [[Filter]] operators through many operators iff: + * 1) the operator is deterministic + * 2) the predicate is deterministic and the operator will not change any of rows. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // SPARK-13473: We can't push the predicate down when the underlying projection output non- // deterministic field(s). Non-deterministic expressions are essentially stateful. This @@ -915,41 +944,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe }) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - } -} - -/** - * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference - * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. - */ -object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, g: Generate) => - // Predicates that reference attributes produced by the `Generate` operator cannot - // be pushed below the operator. 
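Two of the optimizer changes above are small algebraic rewrites. The new `startsAndEndsWith` case turns `input LIKE 'prefix%postfix'` into a `Length` guard plus `StartsWith`/`EndsWith`; the sketch below (plain Scala over a `String`, purely illustrative rather than Catalyst code) shows why the length guard is needed. `BinaryComparisonSimplification` is in the same spirit: for semantically equal, non-nullable operands it folds `a = a`, `a <= a`, `a >= a` to `TRUE` and `a < a`, `a > a` to `FALSE`, and `a <=> a` to `TRUE` regardless of nullability.

```scala
// Shape of the rewrite for LIKE 'prefix%postfix'; names here are illustrative only.
def simplifiedLike(input: String, prefix: String, postfix: String): Boolean =
  input.length >= prefix.length + postfix.length &&
    input.startsWith(prefix) && input.endsWith(postfix)

simplifiedLike("abcXdef", "abc", "def")  // true: 'abcXdef' matches 'abc%def'
simplifiedLike("a", "a", "a")            // false: without the length guard, startsWith and
                                         // endsWith are both true and 'a' would wrongly match 'a%a'
```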
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.references.subsetOf(g.child.outputSet) && cond.deterministic - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) - } else { - filter - } - } -} - -/** - * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only - * non-aggregate attributes (typically literals or grouping expressions). - */ -object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate: Aggregate) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression @@ -975,6 +970,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } else { filter } + + case filter @ Filter(condition, child) + if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] => + // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.deterministic + } + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = child.output + val newGrandChildren = child.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet)) + Filter(newCond, grandchild) + } + val newChild = child.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } + + case filter @ Filter(condition, e @ Except(left, _)) => + pushDownPredicate(filter, e.left) { predicate => + e.copy(left = Filter(predicate, left)) + } + + // two filters should be combine together by other rules + case filter @ Filter(_, f: Filter) => filter + // should not push predicates through sample, or will generate different results. + case filter @ Filter(_, s: Sample) => filter + // TODO: push predicates through expand + case filter @ Filter(_, e: Expand) => filter + + case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + pushDownPredicate(filter, u.child) { predicate => + u.withNewChildren(Seq(Filter(predicate, u.child))) + } + } + + private def pushDownPredicate( + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + // TODO: non-deterministic predicates could be pushed through some operators that do not change + // the rows. 
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => + cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + } + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f35d87ebbd9..00656191354f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case BroadcastHint(child) => + collectProjectsAndFilters(child) + case other => (None, Nil, other, Map.empty) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0a11574f441c..d4447ca32d5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -312,18 +312,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { def cleanArg(arg: Any): Any = arg match { + // Children are checked using sameResult above. + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => cleanExpression(e).canonicalized case other => other } productIterator.map { - // Children are checked using sameResult above. 
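The consolidated `PushDownPredicate` rule replaces the separate project/generate/aggregate rules and also learns to push filters into set operations. A sketch (assuming a `sqlContext` and hypothetical registered tables `t1`/`t2` with an `id` column) of the observable effect:

```scala
// A deterministic predicate over a union is copied into each branch; a
// non-deterministic one (rand()) must stay above the Union.
val unioned = sqlContext.table("t1").unionAll(sqlContext.table("t2"))
unioned.filter("id > 10").explain(true)       // Filter appears below the Union in the optimized plan
unioned.filter("rand() > 0.5").explain(true)  // Filter stays above the Union
```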
- case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanArg(e) case s: Option[_] => s.map(cleanArg) case s: Seq[_] => s.map(cleanArg) case m: Map[_, _] => m.mapValues(cleanArg) - case other => other + case other => cleanArg(other) }.toSeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3353beb09c5..d4fc9e4da944 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by it's references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + override protected def validConstraints: Set[Expression] = { children .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) - .reduce(_ intersect _) + .reduce(merge(_, _)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 0f65028261b8..cde8bd5b9614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -57,7 +57,7 @@ object StringUtils { * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL * @param names the names list to be filtered * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will - * follows regular expression convention, case insensitive match and white spaces + * follow regular expression convention, case insensitive match and white spaces * on both ends will be ignored * @return the filtered names list in order */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 8207d64798bd..711e8707116c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -196,12 +196,11 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => { + case ArrayType(elementType, containsNull) => forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - } - case MapType(keyType, valueType, valueContainsNull) => { + case MapType(keyType, valueType, 
valueContainsNull) => for ( keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- @@ -221,8 +220,7 @@ object RandomDataGenerator { keys.zip(values).toMap } } - } - case StructType(fields) => { + case StructType(fields) => val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) } @@ -232,8 +230,7 @@ object RandomDataGenerator { } else { None } - } - case udt: UserDefinedType[_] => { + case udt: UserDefinedType[_] => val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. @@ -253,7 +250,6 @@ object RandomDataGenerator { } else { None } - } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: @@ -277,7 +273,7 @@ object RandomDataGenerator { val fields = mutable.ArrayBuffer.empty[Any] schema.fields.foreach { f => f.dataType match { - case ArrayType(childType, nullable) => { + case ArrayType(childType, nullable) => val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { null } else { @@ -294,10 +290,8 @@ object RandomDataGenerator { arr } fields += data - } - case StructType(children) => { + case StructType(children) => fields += randomRow(rand, StructType(children)) - } case _ => val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) assert(generator.isDefined, "Unsupported type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index fbcac09ce223..f961fe3292be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -149,6 +149,15 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { // Tables // -------------------------------------------------------------------------- + test("the table type of an external table should be EXTERNAL_TABLE") { + val catalog = newBasicCatalog() + val table = + newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE) + catalog.createTable("db2", table, ignoreIfExists = false) + val actual = catalog.getTable("db2", "external_table1") + assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE) + } + test("drop table") { val catalog = newBasicCatalog() assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) @@ -272,31 +281,37 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) - catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2))) resetState() val catalog2 = newBasicCatalog() assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2))) - catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + catalog2.dropPartitions( + "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) assert(catalog2.listPartitions("db2", "tbl2").isEmpty) } test("drop partitions when database/table does not exist") { val catalog = 
newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) } intercept[AnalysisException] { - catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "does_not_exist", Seq(), ignoreIfNotExists = false) } } test("drop partitions that do not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) } - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) } test("get partition") { @@ -538,8 +553,12 @@ abstract class CatalogTestUtils { identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL_TABLE, storage = storageFormat, - schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")), - partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string"))) + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "string")), + partitionColumnNames = Seq("a", "b")) } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 1850dc815608..426273e1e3e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -233,10 +233,9 @@ class SessionCatalogSuite extends SparkFunSuite { intercept[AnalysisException] { catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true) } - // Table does not exist - intercept[AnalysisException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) - } + // If the table does not exist, we do not issue an exception. Instead, we output an error log + // message to console when ignoreIfNotExists is set to false. 
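A minimal sketch of the drop-table contract described in the comment above: when the table does not exist and ignoreIfNotExists is false, an error is logged instead of an exception being thrown. The toy catalog and logger below stand in for Spark's SessionCatalog and are illustrative only:

import org.slf4j.LoggerFactory

class TinyCatalogSketch {
  private val log = LoggerFactory.getLogger(getClass)
  private val tables = scala.collection.mutable.Set("tbl1", "tbl2")

  def dropTable(name: String, ignoreIfNotExists: Boolean): Unit = {
    if (tables.remove(name)) {
      ()                                            // dropped successfully
    } else if (!ignoreIfNotExists) {
      log.error(s"Table '$name' does not exist")    // log instead of throwing
    }                                               // silently ignore otherwise
  }
}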
+ catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true) } @@ -497,19 +496,25 @@ class SessionCatalogSuite extends SparkFunSuite { val sessionCatalog = new SessionCatalog(externalCatalog) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), + ignoreIfNotExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), Seq(part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2"), + Seq(part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec, part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) } @@ -517,11 +522,15 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfNotExists = false) + TableIdentifier("tbl1", Some("does_not_exist")), + Seq(), + ignoreIfNotExists = false) } intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfNotExists = false) + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false) } } @@ -529,10 +538,14 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false) } catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = true) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = true) } test("get partition") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala new file mode 100644 index 000000000000..7cd038570bbd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("Constant Folding", FixedPoint(50), + NullPropagation, + ConstantFolding, + BooleanSimplification, + BinaryComparisonSimplification, + PruneFilters) :: Nil + } + + val nullableRelation = LocalRelation('a.int.withNullability(true)) + val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + + test("Preserve nullable exprs in general") { + for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) { + val plan = nullableRelation.where(e).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + } + + test("Preserve non-deterministic exprs") { + val plan = nonNullableRelation + .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + + test("Nullable Simplification Primitive: <=>") { + val plan = nullableRelation.select('a <=> 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze + comparePlans(actual, correctAnswer) + } + + test("Non-Nullable Simplification Primitive") { + val plan = nonNullableRelation + .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation + .select( + Alias(TrueLiteral, "(a = a)")(), + Alias(TrueLiteral, "(a <=> a)")(), + Alias(TrueLiteral, "(a <= a)")(), + Alias(TrueLiteral, "(a >= a)")(), + Alias(FalseLiteral, "(a < a)")(), + Alias(FalseLiteral, "(a > a)")()) + .analyze + comparePlans(actual, correctAnswer) + } + + test("Expression Normalization") { + val plan = nonNullableRelation.where( + 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a && + DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) + .analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation.analyze + comparePlans(actual, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 2248e03b2fc5..52b574c0e63c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - PushPredicateThroughProject, + PushDownPredicate, ColumnPruning, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b84ae7c5bb6a..df7529d83f7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(10), SamplePushDown, CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, CollapseProject) :: Nil } @@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a === 3) + .select('a, 'b) .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze @@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a + 1 < 3) + .select('a, 'b) .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) .where('c === 2L || 'aa > 4) .analyze @@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where("s" === "s") + .select('a, 'b) .groupBy('a)('a, count('b) as 'c, "s" as 'd) .where('c === 2L) .analyze @@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("broadcast hint") { + val originalQuery = BroadcastHint(testRelation) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("union") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.where('a === 2L), + testRelation2.where('d === 2L))) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("intersect") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Intersect(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Intersect( + testRelation.where('a === 2L), + testRelation2.where('d === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("except") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val 
originalQuery = Except(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Except( + testRelation.where('a === 2L), + testRelation2) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e2f8146beee7..c1ebf8b09e08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, ReorderJoin, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index 741bc113cfcd..fdde89d079bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -61,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("simplify Like into startsWith and EndsWith") { + val originalQuery = + testRelation + .where(('a like "abc\\%def") || ('a like "abc%def")) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a like "abc\\%def") || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("simplify Like into Contains") { val originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 14fb72a8a343..d8cfec539149 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { Batch("Filter Pushdown and Pruning", Once, CombineFilters, PruneFilters, - PushPredicateThroughProject, + PushDownPredicate, PushPredicateThroughJoin) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 49c1353efb63..81cc6b123cdd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite { .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) + + val a = resolveColumn(tr1, "a") + 
verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) + + val b = resolveColumn(tr1, "b") + verifyConstraints(tr1 + .where('a.attr > 10 && 'b.attr < 10) + .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) } test("propagating constraints in intersect") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 37941cf34e74..467f76193cfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -61,4 +61,9 @@ class SameResultSuite extends SparkFunSuite { test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } + + test("union") { + assertSameResult(Union(Seq(testRelation, testRelation2)), + Union(Seq(testRelation2, testRelation))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index c2633a9f8cd4..086547c793e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -60,7 +60,7 @@ public long durationMs() { /** * Initializes from array of iterators of InternalRow. */ - public abstract void init(int index, Iterator iters[]); + public abstract void init(int index, Iterator[] iters); /** * Append a row to currentRows. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 6cc2fda5871d..ea37a08ab5f5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -27,6 +27,7 @@ import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.DataTypes; @@ -114,57 +115,6 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader } } - /** - * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. 
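The union tests in ConstraintPropagationSuite above exercise the merge added to Union.validConstraints: constraints shared by both children survive as-is, and single-attribute constraints that differ are loosened into a disjunction, turning (A1 && B1) union (A2 && B2) into (A1 || A2) && (B1 || B2). A rough standalone rendering of that rule over plain strings rather than Catalyst expressions, where every constraint references exactly one attribute by construction (names are illustrative):

object ConstraintMergeSketch {
  // A "constraint" here is a predicate string paired with the single attribute it references;
  // Catalyst works with real Expression trees instead.
  final case class Constraint(attr: String, pred: String)

  def merge(a: Set[Constraint], b: Set[Constraint]): Set[Constraint] = {
    val common = a.intersect(b)
    val byAttrA = (a -- common).groupBy(_.attr)
    val byAttrB = (b -- common).groupBy(_.attr)
    // Constraints on the same attribute that differ between the two sides are loosened into a
    // disjunction instead of being dropped entirely.
    val loosened = (byAttrA.keySet intersect byAttrB.keySet).map { attr =>
      val left = byAttrA(attr).map(_.pred).mkString(" AND ")
      val right = byAttrB(attr).map(_.pred).mkString(" AND ")
      Constraint(attr, s"($left) OR ($right)")
    }
    common ++ loosened
  }

  def main(args: Array[String]): Unit = {
    val left = Set(Constraint("a", "a > 10"), Constraint("a", "a IS NOT NULL"))
    val right = Set(Constraint("a", "a > 11"), Constraint("a", "a IS NOT NULL"))
    // Prints the shared "a IS NOT NULL" plus the loosened "(a > 10) OR (a > 11)".
    merge(left, right).foreach(println)
  }
}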
- */ - public boolean nextBoolean() { - if (!useDictionary) { - return dataColumn.readBoolean(); - } else { - return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); - } - } - - public int nextInt() { - if (!useDictionary) { - return dataColumn.readInteger(); - } else { - return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); - } - } - - public long nextLong() { - if (!useDictionary) { - return dataColumn.readLong(); - } else { - return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); - } - } - - public float nextFloat() { - if (!useDictionary) { - return dataColumn.readFloat(); - } else { - return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); - } - } - - public double nextDouble() { - if (!useDictionary) { - return dataColumn.readDouble(); - } else { - return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); - } - } - - public Binary nextBinary() { - if (!useDictionary) { - return dataColumn.readBytes(); - } else { - return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); - } - } - /** * Advances to the next value. Returns true if the value is non-null. */ @@ -200,8 +150,26 @@ void readBatch(int total, ColumnVector column) throws IOException { ColumnVector dictionaryIds = column.reserveDictionaryIds(total); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column, dictionaryIds); + + if (column.hasDictionary() || (rowId == 0 && + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + // Column vector supports lazy decoding of dictionary values so just set the dictionary. + // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some + // non-dictionary encoded values have already been added). + column.setDictionary(dictionary); + } else { + decodeDictionaryIds(rowId, num, column, dictionaryIds); + } } else { + if (column.hasDictionary() && rowId != 0) { + // This batch already has dictionary encoded values but this new page is not. The batch + // does not support a mix of dictionary and not so we will decode the dictionary. 
+ decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); + } column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: @@ -246,11 +214,45 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: + if (column.dataType() == DataTypes.IntegerType || + DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ShortType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + case INT64: + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + case FLOAT: + for (int i = rowId; i < rowId + num; ++i) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } + break; + case DOUBLE: - case BINARY: - column.setDictionary(dictionary); + for (int i = rowId; i < rowId + num; ++i) { + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } break; case INT96: if (column.dataType() == DataTypes.TimestampType) { @@ -263,6 +265,16 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, throw new NotImplementedException(); } break; + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; case FIXED_LEN_BYTE_ARRAY: // DecimalType written in the legacy mode if (DecimalType.is32BitDecimalType(column.dataType())) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0b276e6c773d..ff1f6680a718 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -912,6 +912,11 @@ public void setDictionary(Dictionary dictionary) { this.dictionary = dictionary; } + /** + * Returns true if this column has a dictionary. + */ + public boolean hasDictionary() { return this.dictionary != null; } + /** * Reserve a integer column for ids of dictionary. */ @@ -926,6 +931,13 @@ public ColumnVector reserveDictionaryIds(int capacity) { return dictionaryIds; } + /** + * Returns the underlying integer column for ids of dictionary. + */ + public ColumnVector getDictionaryIds() { + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. 
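The decodeDictionaryIds changes above materialize dictionary ids into the typed ColumnVector storage one physical type at a time, while readBatch keeps the dictionary and defers decoding whenever the whole batch can stay dictionary encoded. A much simplified sketch of the eager INT32 path, with plain arrays standing in for ColumnVector and Parquet's Dictionary (all names below are illustrative):

object DictionaryDecodeSketch {
  // Simplified stand-in: a Parquet dictionary maps small integer ids to decoded values.
  final class IntDictionary(values: Array[Int]) {
    def decodeToInt(id: Int): Int = values(id)
  }

  // Eagerly materialize dictionary ids in [rowId, rowId + num) into the typed output storage,
  // as the INT32 branch above does against a ColumnVector.
  def decodeDictionaryIds(
      rowId: Int,
      num: Int,
      output: Array[Int],
      dictionaryIds: Array[Int],
      dictionary: IntDictionary): Unit = {
    var i = rowId
    while (i < rowId + num) {
      output(i) = dictionary.decodeToInt(dictionaryIds(i))
      i += 1
    }
  }
}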
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java similarity index 83% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java rename to sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java index 8ff7b6549b5f..c7c6e3868f9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java @@ -19,7 +19,6 @@ import org.apache.spark.annotation.Experimental; import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.Dataset; import org.apache.spark.sql.TypedColumn; import org.apache.spark.sql.execution.aggregate.TypedAverage; import org.apache.spark.sql.execution.aggregate.TypedCount; @@ -28,7 +27,7 @@ /** * :: Experimental :: - * Type-safe functions available for {@link Dataset} operations in Java. + * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. * * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. * @@ -43,7 +42,7 @@ public class typed { * * @since 2.0.0 */ - public static TypedColumn avg(MapFunction f) { + public static TypedColumn avg(MapFunction f) { return new TypedAverage(f).toColumnJava(); } @@ -52,7 +51,7 @@ public static TypedColumn avg(MapFunction f) { * * @since 2.0.0 */ - public static TypedColumn count(MapFunction f) { + public static TypedColumn count(MapFunction f) { return new TypedCount(f).toColumnJava(); } @@ -61,7 +60,7 @@ public static TypedColumn count(MapFunction f) { * * @since 2.0.0 */ - public static TypedColumn sum(MapFunction f) { + public static TypedColumn sum(MapFunction f) { return new TypedSumDouble(f).toColumnJava(); } @@ -70,7 +69,7 @@ public static TypedColumn sum(MapFunction f) { * * @since 2.0.0 */ - public static TypedColumn sumLong(MapFunction f) { + public static TypedColumn sumLong(MapFunction f) { return new TypedSumLong(f).toColumnJava(); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index d7f71bd4b089..1343e81569cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -178,10 +178,13 @@ class ContinuousQueryManager(sqlContext: SQLContext) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } + var nextSourceId = 0L val logicalPlan = df.logicalPlan.transform { case StreamingRelation(dataSource, _, output) => // Materialize source to avoid creating it in every batch - val source = dataSource.createSource() + val metadataPath = s"$checkpointLocation/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 // We still need to use the previous `output` instead of `source.schema` as attributes in // "df.logicalPlan" has already used attributes of the previous `output`. 
StreamingExecutionRelation(source, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2f6d8d109fee..e216945fbe5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -1493,6 +1492,8 @@ class Dataset[T] private[sql]( * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. * + * For Java API, use [[randomSplitAsList]]. + * * @group typedrel * @since 2.0.0 */ @@ -1510,6 +1511,20 @@ class Dataset[T] private[sql]( }.toArray } + /** + * Returns a Java list that contains randomly split [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * + * @group typedrel + * @since 2.0.0 + */ + def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = { + val values = randomSplit(weights, seed) + java.util.Arrays.asList(values : _*) + } + /** * Randomly splits this [[Dataset]] with the provided weights. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3de8aa02766d..232ed827083f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -108,6 +108,15 @@ class SparkSqlAstBuilder extends AstBuilder { Option(ctx.key).map(visitTablePropertyKey)) } + /** + * Create a [[ShowCreateTableCommand]] logical plan + */ + override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) { + ShowCreateTableCommand( + ctx.tableIdentifier.table.getText, + Option(ctx.tableIdentifier.db).map(_.getText)) + } + /** * Create a [[RefreshTable]] logical plan. */ @@ -143,10 +152,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { val options = ctx.explainOption.asScala if (options.exists(_.FORMATTED != null)) { - logWarning("EXPLAIN FORMATTED option is ignored.") - } - if (options.exists(_.LOGICAL != null)) { - logWarning("EXPLAIN LOGICAL option is ignored.") + logWarning("Unsupported operation: EXPLAIN FORMATTED option") } // Create the explain comment. @@ -182,7 +188,9 @@ class SparkSqlAstBuilder extends AstBuilder { } } - /** Type to keep track of a table header. */ + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). 
+ */ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) /** @@ -206,7 +214,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (external) { - logWarning("EXTERNAL option is not supported.") + throw new ParseException("Unsupported operation: EXTERNAL option", ctx) } val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText @@ -366,6 +374,22 @@ class SparkSqlAstBuilder extends AstBuilder { } } + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + throw new ParseException("Unsupported operation: PURGE option", ctx) + } + if (ctx.REPLICATION != null) { + throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + } + DropTable( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXISTS != null, + ctx.VIEW != null) + } + /** * Create a [[AlterTableRename]] command. * @@ -378,7 +402,8 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { AlterTableRename( visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to)) + visitTableIdentifier(ctx.to), + ctx.VIEW != null) } /** @@ -394,7 +419,8 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList)) + visitTablePropertyList(ctx.tablePropertyList), + ctx.VIEW != null) } /** @@ -411,7 +437,8 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableUnsetProperties( visitTableIdentifier(ctx.tableIdentifier), visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, - ctx.EXISTS != null) + ctx.EXISTS != null, + ctx.VIEW != null) } /** @@ -475,12 +502,14 @@ class SparkSqlAstBuilder extends AstBuilder { * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec * }}} * - * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning + * ALTER VIEW ... ADD PARTITION ... is not supported because the concept of partitioning * is associated with physical tables */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } // Create partition spec to location mapping. val specsAndLocs = if (ctx.partitionSpec.isEmpty) { ctx.partitionSpecLocation.asScala.map { @@ -496,8 +525,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableAddPartition( visitTableIdentifier(ctx.tableIdentifier), specsAndLocs, - ctx.EXISTS != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -510,11 +538,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitExchangeTablePartition( ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableExchangePartition( - visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... 
EXCHANGE PARTITION ...") } /** @@ -530,8 +555,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableRenamePartition( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), - visitNonOptionalPartitionSpec(ctx.to))( - command(ctx)) + visitNonOptionalPartitionSpec(ctx.to)) } /** @@ -548,13 +572,16 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitDropTablePartitions( ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + if (ctx.PURGE != null) { + throw new AnalysisException(s"Operation not allowed: PURGE") + } AlterTableDropPartition( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), - ctx.EXISTS != null, - ctx.PURGE != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -567,10 +594,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitArchiveTablePartition( ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableArchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") } /** @@ -583,10 +608,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitUnarchiveTablePartition( ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnarchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...") } /** @@ -607,10 +630,7 @@ class SparkSqlAstBuilder extends AstBuilder { case s: GenericFileFormatContext => (Seq.empty[String], Option(s.identifier.getText)) case s: TableFileFormatContext => - val elements = Seq(s.inFmt, s.outFmt) ++ - Option(s.serdeCls).toSeq ++ - Option(s.inDriver).toSeq ++ - Option(s.outDriver).toSeq + val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq (elements.map(string), None) } AlterTableSetFileFormat( @@ -645,10 +665,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableTouch( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") } /** @@ -660,11 +677,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableCompact( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - string(ctx.STRING))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...") } /** @@ -676,10 +689,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableMerge( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... 
CONCATENATE") } /** @@ -774,22 +784,6 @@ class SparkSqlAstBuilder extends AstBuilder { .map(_.identifier.getText)) } - /** - * Create a skew specification. This contains three components: - * - The Skewed Columns - * - Values for which are skewed. The size of each entry must match the number of skewed columns. - * - A store in directory flag. - */ - override def visitSkewSpec( - ctx: SkewSpecContext): (Seq[String], Seq[Seq[String]], Boolean) = withOrigin(ctx) { - val skewedValues = if (ctx.constantList != null) { - Seq(visitConstantList(ctx.constantList)) - } else { - visitNestedConstantList(ctx.nestedConstantList) - } - (visitIdentifierList(ctx.identifierList), skewedValues, ctx.DIRECTORIES != null) - } - /** * Convert a nested constants list into a sequence of string sequences. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index c4594f0480e7..447dbe701815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -446,8 +446,11 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns - val haveTooManyFields = numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields - !willFallback && !haveTooManyFields + val hasTooManyOutputFields = + numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + val hasTooManyInputFields = + plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 0a5a72c52a37..253592028c7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + ctx.addMutableState(hashMapClassName, hashMapTerm, "") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -467,6 +467,7 @@ case class TungstenAggregate( s""" ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala index 7a18d0afce6b..c39a78da6f9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.TypedColumn +import org.apache.spark.sql.{Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator @@ -27,28 +27,20 @@ import org.apache.spark.sql.expressions.Aggregator //////////////////////////////////////////////////////////////////////////////////////////////////// -class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] { - val numeric = implicitly[Numeric[OUT]] - override def zero: OUT = numeric.zero - override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) - override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) - override def finish(reduction: OUT): OUT = reduction - - // TODO(ekl) java api support once this is exposed in scala -} - - class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] { override def zero: Double = 0.0 override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction + override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]() + override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + // Java api support def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) - def toColumnJava(): TypedColumn[IN, java.lang.Double] = { - toColumn(ExpressionEncoder(), ExpressionEncoder()) - .asInstanceOf[TypedColumn[IN, java.lang.Double]] + + def toColumnJava: TypedColumn[IN, java.lang.Double] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] } } @@ -59,11 +51,14 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + // Java api support def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) - def toColumnJava(): TypedColumn[IN, java.lang.Long] = { - toColumn(ExpressionEncoder(), ExpressionEncoder()) - .asInstanceOf[TypedColumn[IN, java.lang.Long]] + + def toColumnJava: TypedColumn[IN, java.lang.Long] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] } } @@ -76,11 +71,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + // Java api support def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) - def toColumnJava(): TypedColumn[IN, java.lang.Long] = { - toColumn(ExpressionEncoder(), ExpressionEncoder()) - .asInstanceOf[TypedColumn[IN, java.lang.Long]] + def toColumnJava: TypedColumn[IN, java.lang.Long] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]] } } @@ -93,10 +90,12 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D (b1._1 + b2._1, b1._2 + b2._2) } + override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]() + override def 
outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + // Java api support def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) - def toColumnJava(): TypedColumn[IN, java.lang.Double] = { - toColumn(ExpressionEncoder(), ExpressionEncoder()) - .asInstanceOf[TypedColumn[IN, java.lang.Double]] + def toColumnJava: TypedColumn[IN, java.lang.Double] = { + toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index aba500ad8de2..344aaff348e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -400,7 +400,7 @@ case class Range( sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) - .mapPartitionsWithIndex((i, _) => { + .mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -444,7 +444,7 @@ case class Range( unsafeRow } } - }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 78664baa569d..7cde04b62619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -38,7 +38,7 @@ private[columnar] trait ColumnAccessor { def hasNext: Boolean - def extractTo(row: MutableRow, ordinal: Int) + def extractTo(row: MutableRow, ordinal: Int): Unit protected def underlyingBuffer: ByteBuffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 9a173367f406..d30655e0c4a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -28,12 +28,12 @@ private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. */ - def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false) + def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false): Unit /** * Appends `row(ordinal)` to the column builder. 
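With the typed aggregator changes above, each Aggregator (TypedSumDouble, TypedSumLong, TypedCount, TypedAverage) now supplies its own bufferEncoder and outputEncoder, and toColumn no longer takes encoders. A sketch of a user-defined aggregator written against that shape; TypedMaxDouble is an invented example, not part of the patch, and assumes Spark's public Encoders helpers:

import org.apache.spark.sql.{Encoder, Encoders, TypedColumn}
import org.apache.spark.sql.expressions.Aggregator

// Invented example following the same shape as TypedSumDouble above, with the aggregator
// supplying its own encoders.
class TypedMaxDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] {
  override def zero: Double = Double.NegativeInfinity
  override def reduce(b: Double, a: IN): Double = math.max(b, f(a))
  override def merge(b1: Double, b2: Double): Double = math.max(b1, b2)
  override def finish(reduction: Double): Double = reduction

  override def bufferEncoder: Encoder[Double] = Encoders.scalaDouble
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

  // The typed column can then be used in Dataset.select / agg.
  def asTypedColumn: TypedColumn[IN, Double] = toColumn
}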
*/ - def appendFrom(row: InternalRow, ordinal: Int) + def appendFrom(row: InternalRow, ordinal: Int): Unit /** * Column statistics information diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 3fd2a93d2926..d84daff7a222 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -422,6 +422,39 @@ case class ShowTablePropertiesCommand( } } +/** + * A command for users to get the DDL of an existing table + * The syntax of using this command in SQL is: + * {{{ + * SHOW CREATE TABLE tableIdentifier + * }}} + */ +case class ShowCreateTableCommand( + tableName: String, + databaseName: Option[String]) + extends RunnableCommand{ + + // The result of SHOW CREATE TABLE is the whole string of DDL command + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("DDL", StringType, nullable = false)::Nil) + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val ddl = catalog.showCreateTable(TableIdentifier(tableName, databaseName)) + if (ddl.contains("CLUSTERED BY") || ddl.contains("SKEWED BY") || ddl.contains("STORED BY")) { + Seq(Row("WARN: This DDL is not supported by Spark SQL natively, " + + "because it contains 'CLUSTERED BY', 'SKEWED BY' or 'STORED BY' clause."), + Row(ddl)) + } else { + Seq(Row(ddl)) + } + } + +} + /** * A command for users to list all of the registered functions. * The syntax of using this command in SQL is: @@ -483,20 +516,38 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { - case Some(info) => - val result = - Row(s"Function: ${info.getName}") :: - Row(s"Class: ${info.getClassName}") :: - Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil - - if (isExtended) { - result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") - } else { - result - } + // Hard code "<>", "!=", "between", and "case" for now as there is no corresponding functions. 
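For context, the ShowCreateTableCommand and the hard-coded DescribeFunction entries above are surfaced through SQL. A hypothetical usage sketch, for example from spark-shell; the table name is invented and the exact quoting accepted by DESCRIBE FUNCTION for operator names may vary:

// Hypothetical spark-shell usage; `my_db.my_table` is a placeholder.
val ddlRows = sqlContext.sql("SHOW CREATE TABLE my_db.my_table").collect()
ddlRows.foreach(row => println(row.getString(0)))

// DESCRIBE FUNCTION now also answers for the hard-coded entries such as "between".
sqlContext.sql("DESCRIBE FUNCTION between").show(truncate = false)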
+ functionName.toLowerCase match { + case "<>" => + Row(s"Function: $functionName") :: + Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil + case "!=" => + Row(s"Function: $functionName") :: + Row(s"Usage: a != b - Returns TRUE if a is not equal to b") :: Nil + case "between" => + Row(s"Function: between") :: + Row(s"Usage: a [NOT] BETWEEN b AND c - " + + s"evaluate if a is [not] in between b and c") :: Nil + case "case" => + Row(s"Function: case") :: + Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + s"When a = b, returns c; when a = d, return e; else return f") :: Nil + case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ + Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } - case None => Seq(Row(s"Function: $functionName not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 20779d68e0fd..fc37a142cda1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,15 +17,19 @@ package org.apache.spark.sql.execution.command +import scala.util.control.NonFatal + import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ + // Note: The definition of these commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL @@ -176,26 +180,48 @@ case class DescribeDatabase( } /** - * A command that renames a table/view. + * Drops a table/view from the metastore and removes it if it is cached. * * The syntax of this command is: * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; + * DROP TABLE [IF EXISTS] table_name; + * DROP VIEW [IF EXISTS] [db_name.]view_name; * }}} */ -case class AlterTableRename( - oldName: TableIdentifier, - newName: TableIdentifier) - extends RunnableCommand { +case class DropTable( + tableName: TableIdentifier, + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - catalog.invalidateTable(oldName) - catalog.renameTable(oldName, newName) + if (!catalog.tableExists(tableName)) { + if (!ifExists) { + val objectName = if (isView) "View" else "Table" + logError(s"$objectName '${tableName.quotedString}' does not exist") + } + } else { + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. 
+ catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + try { + sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(s"${e.getMessage}", e) + } + catalog.invalidateTable(tableName) + catalog.dropTable(tableName, ifExists) + } Seq.empty[Row] } - } /** @@ -209,11 +235,13 @@ case class AlterTableRename( */ case class AlterTableSetProperties( tableName: TableIdentifier, - properties: Map[String, String]) + properties: Map[String, String], + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) val newProperties = table.properties ++ properties if (DDLUtils.isDatasourceTable(newProperties)) { @@ -239,11 +267,13 @@ case class AlterTableSetProperties( case class AlterTableUnsetProperties( tableName: TableIdentifier, propKeys: Seq[String], - ifExists: Boolean) + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -303,53 +333,94 @@ case class AlterTableSerDeProperties( } /** - * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. + * Add Partition in ALTER TABLE: add the table partitions. + * * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, * EXCEPT that it is ILLEGAL to specify a LOCATION clause. * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * }}} */ case class AlterTableAddPartition( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], - ifNotExists: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifNotExists: Boolean) + extends RunnableCommand { + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table add partition is not allowed for tables defined using the datasource API") + } + val parts = partitionSpecsAndLocs.map { case (spec, location) => + // inherit table storage format (possibly except for location) + CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + } + catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + Seq.empty[Row] + } + +} + +/** + * Alter a table partition's spec. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ case class AlterTableRenamePartition( tableName: TableIdentifier, oldPartition: TablePartitionSpec, - newPartition: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + newPartition: TablePartitionSpec) + extends RunnableCommand { -case class AlterTableExchangePartition( - fromTableName: TableIdentifier, - toTableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.renamePartitions( + tableName, Seq(oldPartition), Seq(newPartition)) + Seq.empty[Row] + } + +} /** - * Drop Partition in ALTER TABLE/VIEW: to drop a particular partition for a table/view. + * Drop Partition in ALTER TABLE: to drop a particular partition for a table. + * * This removes the data and metadata for this partition. * The data is actually moved to the .Trash/Current directory if Trash is configured, * unless 'purge' is true, but the metadata is completely lost. * An error message will be issued if the partition does not exist, unless 'ifExists' is true. * Note: purge is always false when the target is a view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * }}} */ case class AlterTableDropPartition( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], - ifExists: Boolean, - purge: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifExists: Boolean) + extends RunnableCommand { -case class AlterTableArchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table drop partition is not allowed for tables defined using the datasource API") + } + catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + Seq.empty[Row] + } -case class AlterTableUnarchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging +} case class AlterTableSetFileFormat( tableName: TableIdentifier, @@ -408,22 +479,6 @@ case class AlterTableSetLocation( } -case class AlterTableTouch( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableCompact( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - compactType: String)(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableMerge( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - case class AlterTableChangeCol( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], @@ -462,5 +517,24 @@ private object DDLUtils { def isDatasourceTable(table: CatalogTable): Boolean = { isDatasourceTable(table.properties) } + + /** + * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, + * issue an exception [[AnalysisException]]. 
+ */ + def verifyAlterTableType( + catalog: SessionCatalog, + tableIdentifier: TableIdentifier, + isView: Boolean): Unit = { + catalog.getTableMetadataOption(tableIdentifier).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead") + case _ => + }) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala new file mode 100644 index 000000000000..0b419851746a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} + +/** + * A command to create a table with the same definition of the given existing table. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name + * LIKE [other_db_name.]existing_table_name + * }}} + */ +case class CreateTableLike( + targetTable: TableIdentifier, + sourceTable: TableIdentifier, + ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (!catalog.tableExists(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") + } + if (catalog.isTemporaryTable(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + } + + val tableToCreate = catalog.getTableMetadata(sourceTable).copy( + identifier = targetTable, + tableType = CatalogTableType.MANAGED_TABLE, + createTime = System.currentTimeMillis, + lastAccessTime = -1).withNewStorage(locationUri = None) + + catalog.createTable(tableToCreate, ifNotExists) + Seq.empty[Row] + } +} + + +// TODO: move the rest of the table commands from ddl.scala to this file + +/** + * A command to create a table. + * + * Note: This is currently used only for creating Hive tables. + * This is not intended for temporary tables. 
+ * + * The syntax of using this command in SQL is: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) + * [STORED AS DIRECTORIES] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} + */ +case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.createTable(table, ifNotExists) + Seq.empty[Row] + } + +} + + +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ +case class AlterTableRename( + oldName: TableIdentifier, + newName: TableIdentifier, + isView: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, oldName, isView) + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f55cedb1b6d7..10fde152ab2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -123,36 +123,58 @@ case class DataSource( } } - /** Returns a source that can be used to continually read data. */ - def createSource(): Source = { + private def inferFileFormatSchema(format: FileFormat): StructType = { + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val allPaths = caseInsensitiveOptions.get("path") + val globbedPaths = allPaths.toSeq.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified) + }.toArray + + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) + userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException("Unable to infer schema. It must be specified manually.") + } + } + + /** Returns the name and schema of the source that can be used to continually read data. 
*/ + def sourceSchema(): (String, StructType) = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sqlContext, userSpecifiedSchema, className, options) + s.sourceSchema(sqlContext, userSpecifiedSchema, className, options) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - val metadataPath = caseInsensitiveOptions.getOrElse("metadataPath", s"$path/_metadata") + (s"FileSource[$path]", inferFileFormatSchema(format)) + case _ => + throw new UnsupportedOperationException( + s"Data source $className does not support streamed reading") + } + } - val allPaths = caseInsensitiveOptions.get("path") - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified) - }.toArray + /** Returns a source that can be used to continually read data. */ + def createSource(metadataPath: String): Source = { + providingClass.newInstance() match { + case s: StreamSourceProvider => + s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options) - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) - val dataSchema = userSpecifiedSchema.orElse { - format.inferSchema( - sqlContext, - caseInsensitiveOptions, - fileCatalog.allFiles()) - }.getOrElse { - throw new AnalysisException("Unable to infer schema. It must be specified manually.") - } + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + + val dataSchema = inferFileFormatSchema(format) def dataFrameBuilder(files: Array[String]): DataFrame = { Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8c183317f610..ac3c52e90179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer -import org.apache.spark.TaskContext -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala @@ -35,14 +33,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS} -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.ExecutedCommand -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils} import org.apache.spark.sql.sources._ import 
org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet /** * Replaces generic operations with specific variants that are designed to work with Spark @@ -110,133 +104,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters, (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil - // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) - if t.partitionSchema.nonEmpty => - // We divide the filter expressions into 3 parts - val partitionColumns = AttributeSet( - t.partitionSchema.map(c => l.output.find(_.name == c.name).get)) - - // Only pruning the partition keys - val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns)) - - // Only pushes down predicates that do not reference partition keys. - val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) - - // Predicates with both partition keys and attributes - val partitionAndNormalColumnFilters = - filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - - val selectedPartitions = t.location.listFiles(partitionFilters) - - logInfo { - val total = t.partitionSpec.partitions.length - val selected = selectedPartitions.length - val percentPruned = (1 - selected.toDouble / total.toDouble) * 100 - s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." - } - - // need to add projections from "partitionAndNormalColumnAttrs" in if it is not empty - val partitionAndNormalColumnAttrs = AttributeSet(partitionAndNormalColumnFilters) - val partitionAndNormalColumnProjs = if (partitionAndNormalColumnAttrs.isEmpty) { - projects - } else { - (partitionAndNormalColumnAttrs ++ projects).toSeq - } - - // Prune the buckets based on the pushed filters that do not contain partitioning key - // since the bucketing key is not allowed to use the columns in partitioning key - val bucketSet = getBuckets(pushedFilters, t.bucketSpec) - val scan = buildPartitionedTableScan( - l, - partitionAndNormalColumnProjs, - pushedFilters, - bucketSet, - t.partitionSpec.partitionColumns, - selectedPartitions, - t.options) - - // Add a Projection to guarantee the original projection: - // this is because "partitionAndNormalColumnAttrs" may be different - // from the original "projects", in elements or their ordering - - partitionAndNormalColumnFilters.reduceLeftOption(expressions.And).map(cf => - if (projects.isEmpty || projects == partitionAndNormalColumnProjs) { - // if the original projection is empty, no need for the additional Project either - execution.Filter(cf, scan) - } else { - execution.Project(projects, execution.Filter(cf, scan)) - } - ).getOrElse(scan) :: Nil - - // TODO: The code for planning bucketed/unbucketed/partitioned/unpartitioned tables contains - // a lot of duplication and produces overly complicated RDDs. - - // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) => - // See buildPartitionedTableScan for the reason that we need to create a shard - // broadcast HadoopConf. 
- val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - - t.bucketSpec match { - case Some(spec) if t.sqlContext.conf.bucketingEnabled => - val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { - (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { - val bucketed = - t.location - .allFiles() - .filterNot(_.getPath.getName startsWith "_") - .groupBy { f => - BucketingUtils - .getBucketId(f.getPath.getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}")) - } - - val bucketedDataMap = bucketed.mapValues { bucketFiles => - t.fileFormat.buildInternalScan( - t.sqlContext, - t.dataSchema, - requiredColumns.map(_.name).toArray, - filters, - None, - bucketFiles, - confBroadcast, - t.options).coalesce(1) - } - - val bucketedRDD = new UnionRDD(t.sqlContext.sparkContext, - (0 until spec.numBuckets).map { bucketId => - bucketedDataMap.getOrElse(bucketId, t.sqlContext.emptyResult: RDD[InternalRow]) - }) - bucketedRDD - } - } - - pruneFilterProject( - l, - projects, - filters, - scanBuilder) :: Nil - - case _ => - pruneFilterProject( - l, - projects, - filters, - (a, f) => - t.fileFormat.buildInternalScan( - t.sqlContext, - t.dataSchema, - a.map(_.name).toArray, - f, - None, - t.location.allFiles(), - confBroadcast, - t.options)) :: Nil - } - case l @ LogicalRelation(baseRelation: TableScan, _, _) => execution.DataSourceScan.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil @@ -248,218 +115,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case _ => Nil } - private def buildPartitionedTableScan( - logicalRelation: LogicalRelation, - projections: Seq[NamedExpression], - filters: Seq[Expression], - buckets: Option[BitSet], - partitionColumns: StructType, - partitions: Seq[Partition], - options: Map[String, String]): SparkPlan = { - val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] - - // Because we are creating one RDD per partition, we need to have a shared HadoopConf. - // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. - val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - val partitionColumnNames = partitionColumns.fieldNames.toSet - - // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder - // will union all partitions and attach partition values if needed. - val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { - (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { - - relation.bucketSpec match { - case Some(spec) if relation.sqlContext.conf.bucketingEnabled => - val requiredDataColumns = - requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) - - // Builds RDD[Row]s for each selected partition. - val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap { - case Partition(partitionValues, files) => - val bucketed = files.groupBy { f => - BucketingUtils - .getBucketId(f.getPath.getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}")) - } - - bucketed.map { bucketFiles => - // Don't scan any partition columns to save I/O. Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with - // those partition values encoded in partition directory paths. 
- val dataRows = relation.fileFormat.buildInternalScan( - relation.sqlContext, - relation.dataSchema, - requiredDataColumns.map(_.name).toArray, - filters, - buckets, - bucketFiles._2, - confBroadcast, - options) - - // Merges data values with partition values. - bucketFiles._1 -> mergeWithPartitionValues( - requiredColumns, - requiredDataColumns, - partitionColumns, - partitionValues, - dataRows) - } - } - - val bucketedDataMap: Map[Int, Seq[RDD[InternalRow]]] = - perPartitionRows.groupBy(_._1).mapValues(_.map(_._2)) - - val bucketed = new UnionRDD(relation.sqlContext.sparkContext, - (0 until spec.numBuckets).map { bucketId => - bucketedDataMap.get(bucketId).map(i => i.reduce(_ ++ _).coalesce(1)).getOrElse { - relation.sqlContext.emptyResult: RDD[InternalRow] - } - }) - bucketed - - case _ => - val requiredDataColumns = - requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) - - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { - case Partition(partitionValues, files) => - val dataRows = relation.fileFormat.buildInternalScan( - relation.sqlContext, - relation.dataSchema, - requiredDataColumns.map(_.name).toArray, - filters, - buckets, - files, - confBroadcast, - options) - - // Merges data values with partition values. - mergeWithPartitionValues( - requiredColumns, - requiredDataColumns, - partitionColumns, - partitionValues, - dataRows) - } - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) - } - } - } - - // Create the scan operator. If needed, add Filter and/or Project on top of the scan. - // The added Filter/Project is on top of the unioned RDD. We do not want to create - // one Filter/Project for every partition. - val sparkPlan = pruneFilterProject( - logicalRelation, - projections, - filters, - scanBuilder) - - sparkPlan - } - - /** - * Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can - * either come from `input` (columns scanned from the data source) or from the partitioning - * values (data from `partitionValues`). This is done *once* per physical partition. When - * the column is from `input`, it just references the same underlying column. When using - * partition columns, the column is populated once. - * TODO: there's probably a cleaner way to do this. 
- */ - private def projectedColumnBatch( - input: ColumnarBatch, - requiredColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - partitionColumnSchema: StructType, - partitionValues: InternalRow) : ColumnarBatch = { - val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns)) - var resultIdx = 0 - var inputIdx = 0 - - while (resultIdx < requiredColumns.length) { - val attr = requiredColumns(resultIdx) - if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) { - result.setColumn(resultIdx, input.column(inputIdx)) - inputIdx += 1 - } else { - require(partitionColumnSchema.fields.count(_.name == attr.name) == 1) - var partitionIdx = 0 - partitionColumnSchema.fields.foreach { f => { - if (f.name.equals(attr.name)) { - ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx) - } - partitionIdx += 1 - }} - } - resultIdx += 1 - } - result - } - - private def mergeWithPartitionValues( - requiredColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - partitionColumnSchema: StructType, - partitionValues: InternalRow, - dataRows: RDD[InternalRow]): RDD[InternalRow] = { - // If output columns contain any partition column(s), we need to merge scanned data - // columns and requested partition columns to form the final result. - if (requiredColumns != dataColumns) { - // Builds `AttributeReference`s for all partition columns so that we can use them to project - // required partition columns. Note that if a partition column appears in `requiredColumns`, - // we should use the `AttributeReference` in `requiredColumns`. - val partitionColumns = { - val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap - partitionColumnSchema.toAttributes.map { a => - requiredColumnMap.getOrElse(a.name, a) - } - } - - val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => { - // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and - // `UnsafeProjection`. Because the projection may also adjust column order. - val mutableJoinedRow = new JoinedRow() - val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues) - val unsafeProjection = - UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) - - // If we are returning batches directly, we need to augment them with the partitioning - // columns. We want to do this without a row by row operation. - var columnBatch: ColumnarBatch = null - var mergedBatch: ColumnarBatch = null - - iterator.map { input => { - if (input.isInstanceOf[InternalRow]) { - unsafeProjection(mutableJoinedRow( - input.asInstanceOf[InternalRow], unsafePartitionValues)) - } else { - require(input.isInstanceOf[ColumnarBatch]) - val inputBatch = input.asInstanceOf[ColumnarBatch] - if (inputBatch != mergedBatch) { - mergedBatch = inputBatch - columnBatch = projectedColumnBatch(inputBatch, requiredColumns, - dataColumns, partitionColumnSchema, partitionValues) - } - columnBatch.setNumRows(inputBatch.numRows()) - columnBatch - } - }} - } - - // This is an internal RDD whose call site the user should not be concerned with - // Since we create many of these (one per partition), the time spent on computing - // the call site may add up. - Utils.withDummyCallSite(dataRows.sparkContext) { - new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) - }.asInstanceOf[RDD[InternalRow]] - } else { - dataRows - } - } - // Get the bucket ID based on the bucketing values. 
// Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { @@ -472,57 +127,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { bucketIdGeneration(mutableRow).getInt(0) } - // Get the bucket BitSet by reading the filters that only contains bucketing keys. - // Note: When the returned BitSet is None, no pruning is possible. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - private def getBuckets( - filters: Seq[Expression], - bucketSpec: Option[BucketSpec]): Option[BitSet] = { - - if (bucketSpec.isEmpty || - bucketSpec.get.numBuckets == 1 || - bucketSpec.get.bucketColumnNames.length != 1) { - // None means all the buckets need to be scanned - return None - } - - // Just get the first because bucketing pruning only works when the column has one column - val bucketColumnName = bucketSpec.get.bucketColumnNames.head - val numBuckets = bucketSpec.get.numBuckets - val matchedBuckets = new BitSet(numBuckets) - matchedBuckets.clear() - - filters.foreach { - case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - // Because we only convert In to InSet in Optimizer when there are more than certain - // items. So it is possible we still get an In expression here that needs to be pushed - // down. - case expressions.In(a: Attribute, list) - if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => - val hSet = list.map(e => e.eval(EmptyRow)) - hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e))) - case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, null)) - case _ => - } - - logInfo { - val selected = matchedBuckets.cardinality() - val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100 - s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions." - } - - // None means all the buckets need to be scanned - if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) - } - // Based on Public API. 
protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 988c785dbe61..468e101fedb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.rdd.{RDD, SqlNewHadoopRDDState} +import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -37,7 +37,6 @@ case class PartitionedFile( } } - /** * A collection of files that should be read as a single task possibly from multiple partitioned * directories. @@ -50,7 +49,7 @@ class FileScanRDD( @transient val sqlContext: SQLContext, readFunction: (PartitionedFile) => Iterator[InternalRow], @transient val filePartitions: Seq[FilePartition]) - extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -65,17 +64,17 @@ class FileScanRDD( if (files.hasNext) { val nextFile = files.next() logInfo(s"Reading File $nextFile") - SqlNewHadoopRDDState.setInputFileName(nextFile.filePath) + InputFileNameHolder.setInputFileName(nextFile.filePath) currentIterator = readFunction(nextFile) hasNext } else { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() false } } override def close() = { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index aa1f76450c23..80a9156ddcdc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -55,15 +55,7 @@ import org.apache.spark.sql.sources._ */ private[sql] object FileSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) - if (files.fileFormat.toString == "TestFileFormat" || - files.fileFormat.isInstanceOf[parquet.DefaultSource] || - files.fileFormat.toString == "ORC" || - files.fileFormat.toString == "LibSVM" || - files.fileFormat.isInstanceOf[csv.DefaultSource] || - files.fileFormat.isInstanceOf[text.DefaultSource] || - files.fileFormat.isInstanceOf[json.DefaultSource]) && - files.sqlContext.conf.useFileScan => + case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: // - partition keys only - used to prune directories to read @@ -72,18 +64,28 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) + // The attribute name of predicate could be 
different than the one in schema in case of + // case insensitive, we should change them to match the one in schema, so we donot need to + // worry about case sensitivity anymore. + val normalizedFilters = filters.map { e => + e transform { + case a: AttributeReference => + a.withName(l.output.find(_.semanticEquals(a)).get.name) + } + } + val partitionColumns = l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(filters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver) // Partition keys are not available in the statistics of the files. - val dataFilters = filters.filter(_.references.intersect(partitionSet).isEmpty) + val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) // Predicates with both partition keys and attributes need to be evaluated after the scan. val afterScanFilters = filterSet -- partitionKeyFilters diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index e31380e17d40..889c0204f8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.util.Utils /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala deleted file mode 100644 index 6ddb218a2263..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ /dev/null @@ -1,284 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.reflect.ClassTag - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} - -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} - -private[spark] class SqlNewHadoopPartition( - rddId: Int, - val index: Int, - rawSplit: InputSplit with Writable) - extends SparkPartition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = 41 * (41 + rddId) + index -} - -/** - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). - * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. - * 1. A shared broadcast Hadoop Configuration. - * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side - * to the shared Hadoop Configuration. - * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side - * and the executor side to the shared Hadoop Configuration. - * - * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. - */ -private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sqlContext: SQLContext, - broadcastedConf: Broadcast[SerializableConfiguration], - @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], - initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[Void, V]], - valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) with Logging { - - protected def getJob(): Job = { - val conf = broadcastedConf.value.value - // "new Job" will make a copy of the conf. Then, it is - // safe to mutate conf properties with initLocalJobFuncOpt - // and initDriverSideJobFuncOpt. 
- val newJob = Job.getInstance(conf) - initLocalJobFuncOpt.map(f => f(newJob)) - newJob - } - - def getConf(isDriverSide: Boolean): Configuration = { - val job = getJob() - if (isDriverSide) { - initDriverSideJobFuncOpt.map(f => f(job)) - } - job.getConfiguration - } - - private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient protected val jobId = new JobID(jobTrackerId, id) - - override def getPartitions: Array[SparkPartition] = { - val conf = getConf(isDriverSide = true) - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = new JobContextImpl(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { - val iter = new Iterator[V] { - val split = theSplit.asInstanceOf[SqlNewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf(isDriverSide = false) - - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) - val existingBytesRead = inputMetrics.bytesRead - - // Sets the thread local variable for the file's name - split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() - } - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } - - // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. - // If we do a coalesce, however, we are likely to compute multiple partitions in the same - // task and in the same thread, in which case we need to avoid override values written by - // previous partitions (SPARK-13071). - def updateBytesRead(): Unit = { - getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) - } - } - - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private[this] var reader: RecordReader[Void, V] = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. 
- context.addTaskCompletionListener(context => close()) - - private[this] var havePair = false - private[this] var finished = false - - override def hasNext: Boolean = { - if (context.isInterrupted()) { - throw new TaskKilledException - } - if (!finished && !havePair) { - finished = !reader.nextKeyValue - if (finished) { - // Close and release the reader here; close() will also be called when the task - // completes, but for tasks that read from many files, it helps to release the - // resources early. - close() - } - havePair = !finished - } - !finished - } - - override def next(): V = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - if (!finished) { - inputMetrics.incRecordsReadInternal(1) - } - if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { - updateBytesRead() - } - reader.getCurrentValue - } - - private def close() { - if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. - try { - reader.close() - } catch { - case e: Exception => - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } finally { - reader = null - } - if (getBytesReadCallback.isDefined) { - updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } - } - } - } - iter - } - - override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { - val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } - locs.getOrElse(split.getLocations.filter(_ != "localhost")) - } - - override def persist(storageLevel: StorageLevel): this.type = { - if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + - " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + - " Use a map transformation to make copies of the records.") - } - super.persist(storageLevel) - } - - /** - * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to - * the given function rather than the index of the partition. 
- */ - private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], - f: (InputSplit, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[SparkPartition] = firstParent[T].partitions - - override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[SqlNewHadoopPartition] - val inputSplit = partition.serializableHadoopSplit.value - f(inputSplit, firstParent[T].iterator(split, context)) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index d2bbf196cb6a..815d1d01ef34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{SerializableConfiguration, Utils} /** A container for all the details required when writing to a table. */ case class WriteRelation( @@ -247,19 +247,16 @@ private[sql] class DefaultWriterContainer( // If anything below fails, we should abort the task. try { - while (iterator.hasNext) { - val internalRow = iterator.next() - writer.writeInternal(internalRow) - } - - commitTask() + Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val internalRow = iterator.next() + writer.writeInternal(internalRow) + } + commitTask() + }(catchBlock = abortTask()) } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) - abortTask() - throw new SparkException("Task failed while writing rows.", cause) + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } def commitTask(): Unit = { @@ -413,37 +410,37 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. 
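// A minimal standalone sketch of the error-handling shape being introduced in this writer:
// run the write loop, and if it fails invoke a cleanup callback (close the current writer,
// abort the task) before propagating the original error. The real code delegates this to
// Utils.tryWithSafeFinallyAndFailureCallbacks; the helper and usage names below are
// illustrative assumptions, not the utility's actual implementation.
def withFailureCallback[T](body: => T)(onFailure: => Unit): T =
  try body catch {
    case t: Throwable =>
      try onFailure catch { case suppressed: Throwable => t.addSuppressed(suppressed) }
      throw t
  }
// e.g. withFailureCallback { writeAllRows(); commitTask() } { closeCurrentWriter(); abortTask() }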
var currentWriter: OutputWriter = null try { - var currentKey: UnsafeRow = null - while (sortedIterator.next()) { - val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] - if (currentKey != nextKey) { - if (currentWriter != null) { - currentWriter.close() - currentWriter = null + Utils.tryWithSafeFinallyAndFailureCallbacks { + var currentKey: UnsafeRow = null + while (sortedIterator.next()) { + val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow] + if (currentKey != nextKey) { + if (currentWriter != null) { + currentWriter.close() + currentWriter = null + } + currentKey = nextKey.copy() + logDebug(s"Writing partition: $currentKey") + + currentWriter = newOutputWriter(currentKey, getPartitionString) } - currentKey = nextKey.copy() - logDebug(s"Writing partition: $currentKey") - - currentWriter = newOutputWriter(currentKey, getPartitionString) + currentWriter.writeInternal(sortedIterator.getValue) + } + if (currentWriter != null) { + currentWriter.close() + currentWriter = null } - currentWriter.writeInternal(sortedIterator.getValue) - } - if (currentWriter != null) { - currentWriter.close() - currentWriter = null - } - commitTask() - } catch { - case cause: Throwable => - logError("Aborting task.", cause) - // call failure callbacks first, so we could have a chance to cleanup the writer. - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause) + commitTask() + }(catchBlock = { if (currentWriter != null) { currentWriter.close() } abortTask() - throw new SparkException("Task failed while writing rows.", cause) + }) + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 95de02cf5c18..7b9d3b605a89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -22,8 +22,7 @@ import java.nio.charset.StandardCharsets import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.{CompressionCodecs, ParseModes} -private[sql] class CSVOptions( - @transient private val parameters: Map[String, String]) +private[sql] class CSVOptions(@transient private val parameters: Map[String, String]) extends Logging with Serializable { private def getChar(paramName: String, default: Char): Char = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 34fcbdf87133..06a371b88bc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -133,37 +133,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - /** - * This supports to eliminate unneeded columns before producing an RDD - * containing all of its tuples as Row objects. This reads all the tokens of each line - * and then drop unneeded tokens without casting and type-checking by mapping - * both the indices produced by `requiredColumns` and the ones of tokens. 
- */ - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: Filter before calling buildInternalScan. - val csvFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val csvOptions = new CSVOptions(options) - val pathsString = csvFiles.map(_.getPath.toUri.toString) - val header = dataSchema.fields.map(_.name) - val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString) - val rows = CSVRelation.parseCsv(tokenizedRdd, dataSchema, requiredColumns, csvOptions) - - val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get)) - rows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredDataSchema) - iterator.map(unsafeProjection) - } - } - private def baseRdd( sqlContext: SQLContext, options: CSVOptions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b7ff5f72427a..065c8572b06a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -251,12 +251,12 @@ object JdbcUtils extends Logging { def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { + df.schema.fields foreach { field => val name = field.name val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") - }} + } if (sb.length < 2) "" else sb.substring(2) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 42cd25a18c95..7364a1dc0658 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -28,18 +28,16 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet class DefaultSource extends FileFormat with DataSourceRegister { @@ -93,35 +91,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - 
dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: Filter files for all formats before calling buildInternalScan. - val jsonFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val parsedOptions: JSONOptions = new JSONOptions(options) - val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) - val rows = JacksonParser.parse( - createBaseRdd(sqlContext, jsonFiles), - requiredDataSchema, - columnNameOfCorruptRecord, - parsedOptions) - - rows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredDataSchema) - iterator.map(unsafeProjection) - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 8bc53bae6c13..aeee2600a19e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -54,9 +54,9 @@ object JacksonParser extends Logging { * with an array. */ def convertRootField( - factory: JsonFactory, - parser: JsonParser, - schema: DataType): Any = { + factory: JsonFactory, + parser: JsonParser, + schema: DataType): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (START_ARRAY, st: StructType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala new file mode 100644 index 000000000000..00352f23ae66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.parquet.hadoop.metadata.CompressionCodecName + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf + +/** + * Options for the Parquet data source. + */ +class ParquetOptions( + @transient private val parameters: Map[String, String], + @transient private val sqlConf: SQLConf) + extends Logging with Serializable { + + import ParquetOptions._ + + /** + * Compression codec to use. 
By default use the value specified in SQLConf. + * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. + */ + val compressionCodec: String = { + val codecName = parameters.getOrElse("compression", sqlConf.parquetCompressionCodec).toLowerCase + if (!shortParquetCompressionCodecNames.contains(codecName)) { + val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) + throw new IllegalArgumentException(s"Codec [$codecName] " + + s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") + } + shortParquetCompressionCodecNames(codecName).name() + } +} + + +object ParquetOptions { + // The parquet compression short names + private val shortParquetCompressionCodecNames = Map( + "none" -> CompressionCodecName.UNCOMPRESSED, + "uncompressed" -> CompressionCodecName.UNCOMPRESSED, + "snappy" -> CompressionCodecName.SNAPPY, + "gzip" -> CompressionCodecName.GZIP, + "lzo" -> CompressionCodecName.LZO) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 5ad95e4b9eac..b91e892f8f21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.{List => JList} import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ @@ -27,23 +26,19 @@ import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} -import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Partition => SparkPartition, SparkException} -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow @@ -53,8 +48,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{AtomicType, DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.SerializableConfiguration private[sql] class DefaultSource extends FileFormat @@ -74,6 +68,8 @@ private[sql] class DefaultSource options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { + val parquetOptions = new ParquetOptions(options, sqlContext.sessionState.conf) + val conf = 
ContextUtil.getConfiguration(job) val committerClass = @@ -84,24 +80,11 @@ private[sql] class DefaultSource if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + - classOf[ParquetOutputCommitter].getCanonicalName) + classOf[ParquetOutputCommitter].getCanonicalName) } else { logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) } - val compressionCodec: Option[String] = options - .get("compression") - .map { codecName => - // Validate if given compression codec is supported or not. - val shortParquetCompressionCodecNames = ParquetRelation.shortParquetCompressionCodecNames - if (!shortParquetCompressionCodecNames.contains(codecName.toLowerCase)) { - val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase) - throw new IllegalArgumentException(s"Codec [$codecName] " + - s"is not available. Available codecs are ${availableCodecs.mkString(", ")}.") - } - codecName.toLowerCase - } - conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, @@ -136,14 +119,7 @@ private[sql] class DefaultSource sqlContext.conf.writeLegacyParquetFormat.toString) // Sets compression scheme - conf.set( - ParquetOutputFormat.COMPRESSION, - ParquetRelation - .shortParquetCompressionCodecNames - .getOrElse( - compressionCodec - .getOrElse(sqlContext.conf.parquetCompressionCodec.toLowerCase), - CompressionCodecName.UNCOMPRESSED).name()) + conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodec) new OutputWriterFactory { override def newInstance( @@ -269,12 +245,12 @@ private[sql] class DefaultSource } /** - * Returns whether the reader will the rows as batch or not. + * Returns whether the reader will return the rows as batch or not. */ override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { val conf = SQLContext.getActive().get.conf - conf.useFileScan && conf.parquetVectorizedReaderEnabled && - conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && + conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && schema.forall(_.dataType.isInstanceOf[AtomicType]) } @@ -393,110 +369,6 @@ private[sql] class DefaultSource } } } - - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - allFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - - // Parquet row group size. We will use this value as the value for - // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value - // of these flags are smaller than the parquet row group size. - val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) - - // Create the function to set variable Parquet confs at both driver and executor side. 
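The compression handling deleted above now lives in the new ParquetOptions class: the codec name supplied by the user (or the session default) is validated against a fixed table of short names before being written into ParquetOutputFormat.COMPRESSION. Below is a runnable sketch of that lookup, using plain strings in place of Parquet's CompressionCodecName so it has no dependencies; the object and method names are illustrative only.

object ParquetCodecLookup {
  // Mirrors the validate-then-resolve step in ParquetOptions.compressionCodec.
  private val shortNames = Map(
    "none" -> "UNCOMPRESSED",
    "uncompressed" -> "UNCOMPRESSED",
    "snappy" -> "SNAPPY",
    "gzip" -> "GZIP",
    "lzo" -> "LZO")

  def resolve(parameters: Map[String, String], sessionDefault: String): String = {
    val codecName = parameters.getOrElse("compression", sessionDefault).toLowerCase
    if (!shortNames.contains(codecName)) {
      val available = shortNames.keys.mkString(", ")
      throw new IllegalArgumentException(
        s"Codec [$codecName] is not available. Available codecs are $available.")
    }
    shortNames(codecName)
  }

  def main(args: Array[String]): Unit = {
    println(resolve(Map("compression" -> "GZip"), sessionDefault = "snappy")) // GZIP
    println(resolve(Map.empty, sessionDefault = "snappy"))                    // SNAPPY
  }
}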
- val initLocalJobFuncOpt = - ParquetRelation.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - parquetBlockSize, - useMetadataCache, - parquetFilterPushDown, - assumeBinaryIsString, - assumeInt96IsTimestamp) _ - - val inputFiles = splitFiles(allFiles).data.toArray - - // Create the function to set input paths at the driver side. - val setInputPaths = - ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ - - val allPrimitiveTypes = dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val inputFormatCls = if (sqlContext.conf.parquetVectorizedReaderEnabled - && allPrimitiveTypes) { - classOf[VectorizedParquetInputFormat] - } else { - classOf[ParquetInputFormat[InternalRow]] - } - - Utils.withDummyCallSite(sqlContext.sparkContext) { - new SqlNewHadoopRDD( - sqlContext = sqlContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = inputFormatCls, - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. - val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } - - // Overridden so we can inject our own cached files statuses. - override def getPartitions: Array[SparkPartition] = { - val inputFormat = new ParquetInputFormat[InternalRow] { - override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) - } - } - - val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition( - id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) - } - } - } - } - } -} - -/** - * The ParquetInputFormat that create VectorizedParquetRecordReader. - */ -final class VectorizedParquetInputFormat extends ParquetInputFormat[InternalRow] { - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): ParquetRecordReader[InternalRow] = { - new VectorizedParquetRecordReader().asInstanceOf[ParquetRecordReader[InternalRow]] - } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. @@ -917,12 +789,4 @@ private[sql] object ParquetRelation extends Logging { // should be removed after this issue is fixed. 
} } - - // The parquet compression short names - val shortParquetCompressionCodecNames = Map( - "none" -> CompressionCodecName.UNCOMPRESSED, - "uncompressed" -> CompressionCodecName.UNCOMPRESSED, - "snappy" -> CompressionCodecName.SNAPPY, - "gzip" -> CompressionCodecName.GZIP, - "lzo" -> CompressionCodecName.LZO) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 99459ba1d377..94ecb7a28663 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -19,14 +19,10 @@ package org.apache.spark.sql.execution.datasources.text import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.{LongWritable, NullWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -35,7 +31,6 @@ import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFile import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet /** * A data source for reading text files. 
@@ -88,45 +83,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - verifySchema(dataSchema) - - val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration - val paths = inputFiles - .filterNot(_.getPath.getName startsWith "_") - .map(_.getPath) - .sortBy(_.toUri) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sqlContext.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) - .mapPartitions { iter => - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - - iter.map { case (_, line) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow - } - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index e3d554c2de20..a8f854136c1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -50,10 +51,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode( - canJoinKeyFitWithinLong, - rewriteKeyExpr(buildKeys), - buildPlan.output) + val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -68,7 +66,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) join(streamedIter, hashed, numOutputRows) } } @@ -105,7 +103,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.getMemorySize()); + | incPeakExecutionMemory($relationTerm.estimatedSize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -118,15 +116,13 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (canJoinKeyFitWithinLong) { + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { // generate the join key as Long - val expr = rewriteKeyExpr(streamedKeys).head - val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + val ev = streamedKeys.head.gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8f45d5712672..d6feedc27244 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException - -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { self: SparkPlan => @@ -59,9 +55,15 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + protected lazy val (buildKeys, streamedKeys) = { + require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), + "Join keys from two sides should have same types") + val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = 
rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + buildSide match { + case BuildLeft => (lkeys, rkeys) + case BuildRight => (rkeys, lkeys) + } } /** @@ -69,7 +71,7 @@ trait HashJoin { * * If not, returns the original expressions. */ - def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { var keyExpr: Expression = null var width = 0 keys.foreach { e => @@ -84,17 +86,8 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 - // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same - // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys - // with two same ints have hash code 0, we rotate the bits of second one. - val rotated = if (e.dataType == IntegerType) { - // (e >>> 15) | (e << 17) - BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) - } else { - e - } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -105,17 +98,11 @@ trait HashJoin { keyExpr :: Nil } - protected lazy val canJoinKeyFitWithinLong: Boolean = { - val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) - val key = rewriteKeyExpr(buildKeys) - sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] - } - protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) + UnsafeProjection.create(buildKeys) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) + UnsafeProjection.create(streamedKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 5ccb435686f2..0427db4e3bf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,24 +18,22 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} -import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} -import org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. 
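The refactored HashJoin above now binds and rewrites the join keys once for both sides: when every key is a narrow integral type, rewriteKeyExpr folds them into a single Long by shifting the accumulated value left and OR-ing in the next key's low bits, which is what allows the long-keyed relation to be used. A plain-Scala illustration of that packing arithmetic follows; the width bookkeeping and the fallback to the original expressions when the keys do not fit into 8 bytes are omitted.

object PackedJoinKey {
  // Packs (value, width-in-bytes) pairs into one Long, mirroring the
  // shift-and-mask idea for keys whose total width fits in 8 bytes.
  def pack(keys: Seq[(Long, Int)]): Long =
    keys.foldLeft(0L) { case (acc, (value, widthInBytes)) =>
      val bits = widthInBytes * 8
      val mask = if (bits == 64) -1L else (1L << bits) - 1
      (acc << bits) | (value & mask)
    }

  def main(args: Array[String]): Unit = {
    val packed = pack(Seq((7L, 4), (-3L, 4)))   // two IntegerType keys
    println(f"0x$packed%016x")                  // 0x00000007fffffffd
  }
}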
*/ -private[execution] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns matched rows. * @@ -74,51 +72,36 @@ private[execution] sealed trait HashedRelation { */ def asReadOnlyCopy(): HashedRelation - /** - * Returns the size of used memory. - */ - def getMemorySize: Long = 1L // to make the test happy - /** * Release any used resources. */ - def close(): Unit = {} - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { - out.writeInt(serialized.length) // Write the length of serialized bytes first - out.write(serialized) - } - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def readBytes(in: ObjectInput): Array[Byte] = { - val serializedSize = in.readInt() // Read the length of serialized bytes first - val bytes = new Array[Byte](serializedSize) - in.readFully(bytes) - bytes - } + def close(): Unit } private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. - * - * Note: The caller should make sure that these InternalRow are different objects. */ def apply( - canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int = 64): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int = 64, + taskMemoryManager: TaskMemoryManager = null): HashedRelation = { + val mm = Option(taskMemoryManager).getOrElse { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } - if (canJoinKeyFitWithinLong) { - LongHashedRelation(input, keyGenerator, sizeEstimate) + if (key.length == 1 && key.head.dataType == LongType) { + LongHashedRelation(input, key, sizeEstimate, mm) } else { - UnsafeHashedRelation( - input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + UnsafeHashedRelation(input, key, sizeEstimate, mm) } } } @@ -133,22 +116,17 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with KnownSizeEstimation with Externalizable { + extends HashedRelation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() - override def asReadOnlyCopy(): UnsafeHashedRelation = + override def asReadOnlyCopy(): UnsafeHashedRelation = { new UnsafeHashedRelation(numFields, binaryMap) - - override def getMemorySize: Long = { - binaryMap.getTotalMemoryConsumption } - override def estimatedSize: Long = { - binaryMap.getTotalMemoryConsumption - } + override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption // re-used in get()/getValue() var resultRow = new UnsafeRow(numFields) @@ -276,20 +254,10 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - keyGenerator: UnsafeProjection, - sizeEstimate: Int): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): HashedRelation = { - val taskMemoryManager = if (TaskContext.get() != null) { - TaskContext.get().taskMemoryManager() - } else { - new TaskMemoryManager( - new StaticMemoryManager( - new 
SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -300,6 +268,7 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows + val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -322,143 +291,420 @@ private[joins] object UnsafeHashedRelation { } /** - * An interface for a hashed relation that the key is a Long. + * An append-only hash map mapping from key of Long to UnsafeRow. + * + * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array + * (`page`) in this format: + * + * [bytes of row1][address1][bytes of row2][address1] ... + * + * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key + * could have multiple values. the address at the end of last value for every key is 0. + * + * The keys and addresses of their values could be stored in two modes: + * + * 1) sparse mode: the keys and addresses are stored in `array` as: + * + * [key1][address1][key2][address2]...[] + * + * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 + * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address + * hash collision. + * + * 2) dense mode: all the addresses are packed into a single array of long, as: + * + * [address1] [address2] ... + * + * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is + * determined by `key1 - minKey`. + * + * The map is created as sparse mode, then key-value could be appended into it. Once finish + * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * to probe. + * + * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ */ -private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Iterator[InternalRow] = { - get(key.getLong(0)) +private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable { + + // Whether the keys are stored in dense mode or not. + private var isDense = false + + // The minimum key + private var minKey = Long.MaxValue + + // The maxinum key + private var maxKey = Long.MinValue + + // The array to store the key and offset of UnsafeRow in the page. + // + // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... + // Dense mode: [offset1 | size1] [offset2 | size2] + private var array: Array[Long] = null + private var mask: Int = 0 + + // The page to store all bytes of UnsafeRow and the pointer to next rows. + // [row1][pointer1] [row2][pointer2] + private var page: Array[Byte] = null + + // Current write cursor in the page. + private var cursor = Platform.BYTE_ARRAY_OFFSET + + // The total number of values of all keys. + private var numValues = 0 + + // The number of unique keys. 
+ private var numKeys = 0 + + // needed by serializer + def this() = { + this( + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0), + 0) } - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + + private def acquireMemory(size: Long): Unit = { + // do not support spilling + val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + if (got < size) { + freeMemory(got) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " + + s"got $got bytes") + } } -} -private[joins] final class GeneralLongHashedRelation( - private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) - extends LongHashedRelation with Externalizable { + private def freeMemory(size: Long): Unit = { + mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) + } - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) + private def init(): Unit = { + if (mm != null) { + var n = 1 + while (n < capacity) n *= 2 + acquireMemory(n * 2 * 8 + (1 << 20)) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + page = new Array[Byte](1 << 20) // 1M bytes + } + } + + init() + + def spill(size: Long, trigger: MemoryConsumer): Long = 0L + + /** + * Returns whether all the keys are unique. + */ + def keyIsUnique: Boolean = numKeys == numValues - override def keyIsUnique: Boolean = false + /** + * Returns total memory consumption. + */ + def getTotalMemoryConsumption: Long = array.length * 8 + page.length - override def asReadOnlyCopy(): GeneralLongHashedRelation = - new GeneralLongHashedRelation(hashTable) + /** + * Returns the first slot of array that stores the keys (sparse mode). + */ + private def firstSlot(key: Long): Int = { + val h = key * 0x9E3779B9L + (h ^ (h >> 32)).toInt & mask + } - override def get(key: Long): Iterator[InternalRow] = { - val rows = hashTable.get(key) - if (rows != null) { - rows.toIterator + /** + * Returns the next probe in the array. + */ + private def nextSlot(pos: Int): Int = (pos + 2) & mask + + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { + val offset = address >>> 32 + val size = address & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + resultRow + } + + /** + * Returns the single UnsafeRow for the given key, or null if not found. + */ + def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + return getRow(array(idx), resultRow) + } } else { - null + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return getRow(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } } + null } - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + /** + * Returns an iterator of UnsafeRow for multiple linked values.
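The lookups above rely on the sparse layout: keys and packed addresses sit in adjacent slots of a power-of-two-sized Long array, firstSlot scrambles the key with a multiplicative hash, and a collision simply steps to the next key/address pair. Below is a self-contained toy version of that probing; addresses here are arbitrary non-zero Longs rather than packed offset/size values, and the table is assumed never to fill up.

object SparseSlots {
  private val array = new Array[Long](1 << 4) // [key1, address1, key2, address2, ...]
  private val mask = array.length - 2

  private def firstSlot(key: Long): Int = {
    val h = key * 0x9E3779B9L
    (h ^ (h >> 32)).toInt & mask
  }

  private def nextSlot(pos: Int): Int = (pos + 2) & mask

  // Addresses must be non-zero: a zero address marks an empty slot.
  def put(key: Long, address: Long): Unit = {
    var pos = firstSlot(key)
    while (array(pos + 1) != 0 && array(pos) != key) pos = nextSlot(pos)
    array(pos) = key
    array(pos + 1) = address
  }

  def get(key: Long): Option[Long] = {
    var pos = firstSlot(key)
    while (array(pos + 1) != 0) {
      if (array(pos) == key) return Some(array(pos + 1))
      pos = nextSlot(pos)
    }
    None
  }

  def main(args: Array[String]): Unit = {
    put(42L, 0x10L)
    put(58L, 0x20L)
    println((get(42L), get(58L), get(7L))) // (Some(16),Some(32),None)
  }
}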
+ */ + private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + new Iterator[UnsafeRow] { + var addr = address + override def hasNext: Boolean = addr != 0 + override def next(): UnsafeRow = { + val offset = addr >>> 32 + val size = addr & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + addr = Platform.getLong(page, offset + size) + resultRow + } + } } - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + /** + * Returns an iterator for all the values for the given key, or null if no value is found. + */ + def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + return valueIter(array(idx), resultRow) + } + } else { + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return valueIter(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } + } + null } -} -/** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. - * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ -private[joins] final class LongArrayRelation( - private var numFields: Int, - private var start: Long, - private var offsets: Array[Int], - private var sizes: Array[Int], - private var bytes: Array[Byte] - ) extends LongHashedRelation with Externalizable { + /** + * Appends the key and row into this map. + */ + def append(key: Long, row: UnsafeRow): Unit = { + if (key < minKey) { + minKey = key + } + if (key > maxKey) { + maxKey = key + } - // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, 0L, null, null, null) + // There are 8 bytes for the pointer to the next value + if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { + val used = page.length + if (used * 2L > (1L << 31)) { + sys.error("Can't allocate a page that is larger than 2G") + } + acquireMemory(used * 2) + val newPage = new Array[Byte](used * 2) + System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) + page = newPage + freeMemory(used) + } - override def keyIsUnique: Boolean = true + // copy the bytes of UnsafeRow + val offset = cursor + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + cursor += row.getSizeInBytes + Platform.putLong(page, cursor, 0) + cursor += 8 + numValues += 1 + updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) + } - override def asReadOnlyCopy(): LongArrayRelation = { - new LongArrayRelation(numFields, start, offsets, sizes, bytes) + /** + * Updates the address in the array for the given key. + */ + private def updateIndex(key: Long, address: Long): Unit = { + var pos = firstSlot(key) + while (array(pos) != key && array(pos + 1) != 0) { + pos = nextSlot(pos) + } + if (array(pos + 1) == 0) { + // this is the first value for this key, put the address in array.
+ array(pos) = key + array(pos + 1) = address + numKeys += 1 + if (numKeys * 4 > array.length) { + // reached half of the capacity + growArray() + } + } else { + // there are some values for this key, put the address at the front of them. + val pointer = (address >>> 32) + (address & 0xffffffffL) + Platform.putLong(page, pointer, array(pos + 1)) + array(pos + 1) = address + } } - override def getMemorySize: Long = { - offsets.length * 4 + sizes.length * 4 + bytes.length + private def growArray(): Unit = { + var old_array = array + val n = array.length + numKeys = 0 + acquireMemory(n * 2 * 8) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + var i = 0 + while (i < old_array.length) { + if (old_array(i + 1) > 0) { + updateIndex(old_array(i), old_array(i + 1)) + } + i += 2 + } + old_array = null // release the reference to old array + freeMemory(n * 8) } - override def get(key: Long): Iterator[InternalRow] = { - val row = getValue(key) - if (row != null) { - Seq(row).toIterator - } else { - null + /** + * Try to turn the map into dense mode, which is faster to probe. + */ + def optimize(): Unit = { + val range = maxKey - minKey + // Convert to dense mode if it does not require more memory or could fit within L1 cache + if (range < array.length || range < 1024) { + try { + acquireMemory((range + 1) * 8) + } catch { + case e: SparkException => + // there is not enough memory to convert + return + } + val denseArray = new Array[Long]((range + 1).toInt) + var i = 0 + while (i < array.length) { + if (array(i + 1) > 0) { + val idx = (array(i) - minKey).toInt + denseArray(idx) = array(i + 1) + } + i += 2 + } + val old_length = array.length + array = denseArray + isDense = true + freeMemory(old_length * 8) } } - var resultRow = new UnsafeRow(numFields) - override def getValue(key: Long): InternalRow = { - val idx = (key - start).toInt - if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - resultRow - } else { - null + /** + * Free all the memory acquired by this map.
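append() and updateIndex() above keep a single index entry per key: each new row is written into the page followed by an 8-byte pointer holding the previous head address for that key, and the index is repointed at the new row, so get() walks a newest-first chain that ends at address 0. Below is a toy model of that chaining, with String payloads and Int addresses standing in for the packed offset/size Longs.

import scala.collection.mutable

object ChainedValues {
  // page(i) = (payload, addressOfNextValueForSameKey); address 0 terminates a
  // chain, so slot 0 is left unused, mirroring the 0 pointer above.
  private val page = mutable.ArrayBuffer[(String, Int)](("", 0))
  private val index = mutable.Map[Long, Int]()

  def append(key: Long, value: String): Unit = {
    val previousHead = index.getOrElse(key, 0)
    page += ((value, previousHead))   // new value points back at the old head
    index(key) = page.length - 1      // index now points at the newest value
  }

  def get(key: Long): Iterator[String] = new Iterator[String] {
    private var addr = index.getOrElse(key, 0)
    def hasNext: Boolean = addr != 0
    def next(): String = { val (v, nxt) = page(addr); addr = nxt; v }
  }

  def main(args: Array[String]): Unit = {
    append(42L, "row1"); append(42L, "row2"); append(7L, "rowA")
    println(get(42L).toList) // List(row2, row1) -- newest value first, as above
  }
}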
+ */ + def free(): Unit = { + if (page != null) { + freeMemory(page.length) + page = null + } + if (array != null) { + freeMemory(array.length * 8) + array = null } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(numFields) - out.writeLong(start) - out.writeInt(sizes.length) - var i = 0 - while (i < sizes.length) { - out.writeInt(sizes(i)) - i += 1 + out.writeBoolean(isDense) + out.writeLong(minKey) + out.writeLong(maxKey) + out.writeInt(numKeys) + out.writeInt(numValues) + + out.writeInt(array.length) + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + out.write(buffer, 0, size) + offset += size } - out.writeInt(bytes.length) - out.write(bytes) + + val used = cursor - Platform.BYTE_ARRAY_OFFSET + out.writeInt(used) + out.write(page, 0, used) } override def readExternal(in: ObjectInput): Unit = { - numFields = in.readInt() - resultRow = new UnsafeRow(numFields) - start = in.readLong() + isDense = in.readBoolean() + minKey = in.readLong() + maxKey = in.readLong() + numKeys = in.readInt() + numValues = in.readInt() + val length = in.readInt() - // read sizes of rows - sizes = new Array[Int](length) - offsets = new Array[Int](length) - var i = 0 - var offset = 0 - while (i < length) { - offsets(i) = offset - sizes(i) = in.readInt() - offset += sizes(i) - i += 1 + array = new Array[Long](length) + mask = length - 2 + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + in.readFully(buffer, 0, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + offset += size + } + + val numBytes = in.readInt() + page = new Array[Byte](numBytes) + in.readFully(page) + } +} + +private[joins] class LongHashedRelation( + private var nFields: Int, + private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { + + private var resultRow: UnsafeRow = new UnsafeRow(nFields) + + // Needed for serialization (it is public to make Java serialization work) + def this() = this(0, null) + + override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) + + override def estimatedSize: Long = map.getTotalMemoryConsumption + + override def get(key: InternalRow): Iterator[InternalRow] = { + if (key.isNullAt(0)) { + null + } else { + get(key.getLong(0)) + } + } + + override def getValue(key: InternalRow): InternalRow = { + if (key.isNullAt(0)) { + null + } else { + getValue(key.getLong(0)) } - // read all the bytes - val total = in.readInt() - assert(total == offset) - bytes = new Array[Byte](total) - in.readFully(bytes) + } + + override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow) + + override def getValue(key: Long): InternalRow = map.getValue(key, resultRow) + + override def keyIsUnique: Boolean = map.keyIsUnique + + override def close(): Unit = { + map.free() + } + + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(nFields) + out.writeObject(map) + } + + override def readExternal(in: ObjectInput): Unit = { + nFields = in.readInt() + resultRow = new UnsafeRow(nFields) + map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } } @@ -466,96 +712,45 @@ private[joins] final class 
LongArrayRelation( * Create hashed relation with key that is long. */ private[joins] object LongHashedRelation { - - val DENSE_FACTOR = 0.2 - def apply( - input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - // TODO: use LongToBytesMap for better memory efficiency - val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows var numFields = 0 - var keyIsUnique = true - var minKey = Long.MaxValue - var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.anyNull) { + if (!rowKey.isNullAt(0)) { val key = rowKey.getLong(0) - minKey = math.min(minKey, key) - maxKey = math.max(maxKey, key) - val existingMatchList = hashTable.get(key) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[UnsafeRow]() - hashTable.put(key, newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += unsafeRow - } - } - - if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 - } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 + map.append(key, unsafeRow) } - new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) - } else { - new GeneralLongHashedRelation(hashTable) } + map.optimize() + new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
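With the key expressions bound up front, the builders above pick the implementation purely from the key schema: HashedRelation.apply uses the new LongToUnsafeRowMap-backed LongHashedRelation when there is exactly one key of LongType and keeps the BytesToBytesMap-backed UnsafeHashedRelation otherwise. A schematic of that rule with stand-in types (these are not the Spark classes):

object RelationDispatch {
  sealed trait DataType
  case object LongType extends DataType
  case object StringType extends DataType

  // Stand-in for the dispatch in HashedRelation.apply: one LongType key selects
  // the long-keyed map, anything else the generic unsafe-row map.
  def relationKind(keyTypes: Seq[DataType]): String =
    if (keyTypes.length == 1 && keyTypes.head == LongType) "LongHashedRelation"
    else "UnsafeHashedRelation"

  def main(args: Array[String]): Unit = {
    println(relationKind(Seq(LongType)))             // LongHashedRelation
    println(relationKind(Seq(LongType, StringType))) // UnsafeHashedRelation
  }
}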
*/ -private[execution] case class HashedRelationBroadcastMode( - canJoinKeyFitWithinLong: Boolean, - keys: Seq[Expression], - attributes: Seq[Attribute]) extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) + extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - val generator = UnsafeProjection.create(keys, attributes) - HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) + HashedRelation(rows.iterator, canonicalizedKey, rows.length) } - private lazy val canonicalizedKeys: Seq[Expression] = { - keys.map { e => - BindReferences.bindReference(e.canonicalized, attributes) - } + private lazy val canonicalizedKey: Seq[Expression] = { + key.map { e => e.canonicalized } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => - canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && - canonicalizedKeys == m.canonicalizedKeys + case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index bf8609637928..0c3e3c3fc18a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.memory.MemoryMode +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -57,54 +56,20 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val context = TaskContext.get() - if (!canJoinKeyFitWithinLong) { - // build BytesToBytesMap - val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) - // This relation is usually used until the end of task. - context.addTaskCompletionListener((t: TaskContext) => - relation.close() - ) - return relation - } - - // try to acquire some memory for the hash table, it could trigger other operator to free some - // memory. The memory acquired here will mostly be used until the end of task. - val memoryManager = context.taskMemoryManager() - var acquired = 0L - var used = 0L + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + // This relation is usually used until the end of task. context.addTaskCompletionListener((t: TaskContext) => - memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) + relation.close() ) - - val copiedIter = iter.map { row => - // It's hard to guess what's exactly memory will be used, we have a rough guess here. 
- // TODO: use LongToBytesMap instead of HashMap for memory efficiency - // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers - val needed = 150 + row.getSizeInBytes - if (needed > acquired - used) { - val got = memoryManager.acquireExecutionMemory( - Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) - acquired += got - if (got < needed) { - throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + - "hash join, please use sort merge join by setting " + - "spark.sql.join.preferSortMergeJoin=true") - } - } - used += needed - // HashedRelation requires that the UnsafeRow should be separate objects. - row.copy() - } - - HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) + relation } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) + val hashed = buildHashedRelation(buildIter) join(streamIter, hashed, numOutputRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 688e051e1fee..87dd27a2b1ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -59,12 +59,14 @@ class StreamExecution( * Tracks how much data we have processed and committed to the sink or state store from each * input source. */ + @volatile private[sql] var committedOffsets = new StreamProgress /** * Tracks the offsets that are available to be processed, but have not yet be committed to the * sink. */ + @volatile private var availableOffsets = new StreamProgress /** The current batchId or -1 if execution has not yet been initialized. */ @@ -111,7 +113,8 @@ class StreamExecution( /** Returns current status of all the sources. */ override def sourceStatuses: Array[SourceStatus] = { - sources.map(s => new SourceStatus(s.toString, availableOffsets.get(s))).toArray + val localAvailableOffsets = availableOffsets + sources.map(s => new SourceStatus(s.toString, localAvailableOffsets.get(s))).toArray } /** Returns current status of the sink. */ @@ -228,7 +231,7 @@ class StreamExecution( * Queries all of the sources to see if any new data is available. When there is new data the * batchId counter is incremented and a new log entry is written with the newest offsets. */ - private def constructNextBatch(): Boolean = { + private def constructNextBatch(): Unit = { // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). // If we interrupt some thread running Shell.runCommand, we may hit this issue. // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand" @@ -241,7 +244,15 @@ class StreamExecution( } availableOffsets ++= newData - if (dataAvailable) { + val hasNewData = awaitBatchLock.synchronized { + if (dataAvailable) { + true + } else { + noNewData = true + false + } + } + if (hasNewData) { // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). // If we interrupt some thread running Shell.runCommand, we may hit this issue. 
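constructNextBatch above now decides under awaitBatchLock whether any source has new data, setting noNewData and waking waiters when there is none; processAllAvailable, rewritten a little further down, blocks on the same lock until it observes that flag. A stand-alone sketch of the handshake follows; the names mirror StreamExecution, but this is a toy, not the Spark class, and it leaves out the streamDeathCause check.

object BatchSignaling {
  private val awaitBatchLock = new Object
  @volatile private var noNewData = false

  // Called by the batch thread: signal waiters when no new data was found.
  def constructNextBatch(dataAvailable: Boolean): Unit = awaitBatchLock.synchronized {
    if (!dataAvailable) {
      noNewData = true
      awaitBatchLock.notifyAll()
    }
  }

  // Called by tests/clients: block until a batch completes with no new data.
  def processAllAvailable(): Unit = awaitBatchLock.synchronized {
    noNewData = false
    while (!noNewData) {
      awaitBatchLock.wait(10000)
    }
  }

  def main(args: Array[String]): Unit = {
    val trigger = new Thread(new Runnable {
      def run(): Unit = (1 to 100).foreach { _ =>
        Thread.sleep(20)
        constructNextBatch(dataAvailable = false)
      }
    })
    trigger.setDaemon(true)
    trigger.start()
    processAllAvailable()
    println("caught up: no new data")
  }
}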
// As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set @@ -254,15 +265,11 @@ class StreamExecution( } currentBatchId += 1 logInfo(s"Committed offsets for batch $currentBatchId.") - true } else { - noNewData = true awaitBatchLock.synchronized { // Wake up any threads that are waiting for the stream to progress. awaitBatchLock.notifyAll() } - - false } } @@ -353,7 +360,10 @@ class StreamExecution( * least the given `Offset`. This method is indented for use primarily when writing tests. */ def awaitOffset(source: Source, newOffset: Offset): Unit = { - def notDone = !committedOffsets.contains(source) || committedOffsets(source) < newOffset + def notDone = { + val localCommittedOffsets = committedOffsets + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + } while (notDone) { logInfo(s"Waiting until $newOffset at $source") @@ -365,13 +375,17 @@ class StreamExecution( /** A flag to indicate that a batch has completed with no new data available. */ @volatile private var noNewData = false - override def processAllAvailable(): Unit = { + override def processAllAvailable(): Unit = awaitBatchLock.synchronized { noNewData = false - while (!noNewData) { - awaitBatchLock.synchronized { awaitBatchLock.wait(10000) } - if (streamDeathCause != null) { throw streamDeathCause } + while (true) { + awaitBatchLock.wait(10000) + if (streamDeathCause != null) { + throw streamDeathCause + } + if (noNewData) { + return + } } - if (streamDeathCause != null) { throw streamDeathCause } } override def awaitTermination(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f951dea735d9..d2872e49ce28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.datasources.DataSource object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { - val source = dataSource.createSource() - StreamingRelation(dataSource, source.toString, source.schema.toAttributes) + val (name, schema) = dataSource.sourceSchema() + StreamingRelation(dataSource, name, schema.toAttributes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 351ef404a8e3..3820968324bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val encoder = encoderFor[A] protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + + @GuardedBy("this") protected val batches = new ArrayBuffer[Dataset[A]] + @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) def schema: StructType = encoder.schema @@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def 
addData(data: TraversableOnce[A]): Offset = { import sqlContext.implicits._ + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") this.synchronized { currentOffset = currentOffset + 1 - val ds = data.toVector.toDS() - logDebug(s"Adding ds: $ds") batches.append(ds) currentOffset } @@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${output.mkString(",")}]" - override def getOffset: Option[Offset] = if (batches.isEmpty) { - None - } else { - Some(currentOffset) + override def getOffset: Option[Offset] = synchronized { + if (batches.isEmpty) { + None + } else { + Some(currentOffset) + } } /** @@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val startOrdinal = start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = batches.slice(startOrdinal, endOrdinal) + val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) */ class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ + @GuardedBy("this") private val batches = new ArrayBuffer[Array[Row]]() /** Returns all rows that are stored in this [[Sink]]. */ @@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { batches.flatten } - def lastBatch: Seq[Row] = batches.last + def lastBatch: Seq[Row] = synchronized { batches.last } def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => @@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { }.mkString("\n") } - override def addBatch(batchId: Long, data: DataFrame): Unit = { + override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { if (batchId == batches.size) { logDebug(s"Committing batch $batchId") batches.append(data.collect()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1e0a4a5d4ff0..3335755fd3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -46,12 +46,14 @@ import org.apache.spark.util.Utils * Usage: * To update the data in the state store, the following order of operations are needed. * - * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store - * - store.update(...) + * // get the right store + * - val store = StateStore.get( + * StateStoreId(checkpointLocation, operatorId, partitionId), ..., version, ...) + * - store.put(...) * - store.remove(...) 
- * - store.commit() // commits all the updates to made with version number + * - store.commit() // commits all the updates to made; the new version will be returned * - store.iterator() // key-value data after last commit as an iterator - * - store.updates() // updates made in the last as an iterator + * - store.updates() // updates made in the last commit as an iterator * * Fault-tolerance model: * - Every set of updates is written to a delta file before committing. @@ -99,7 +101,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or cancelled") + verify(state == UPDATING, "Cannot remove after already committed or aborted") val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) @@ -109,7 +111,7 @@ private[state] class HDFSBackedStateStoreProvider( // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => - // Value existed in prev version and updated/removed, mark it as updated + // Value existed in previous version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => // There was no prior update, so mark this as added or updated according to its presence @@ -122,7 +124,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or cancelled") + verify(state == UPDATING, "Cannot remove after already committed or aborted") val keyIter = mapToUpdate.keySet().iterator() while (keyIter.hasNext) { val key = keyIter.next @@ -146,7 +148,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or aborted") try { finalizeDeltaFile(tempDeltaFileStream) @@ -161,8 +163,10 @@ private[state] class HDFSBackedStateStoreProvider( } } - /** Cancel all the updates made on this store. This store will not be usable any more. */ + /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { + verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() @@ -170,7 +174,7 @@ private[state] class HDFSBackedStateStoreProvider( if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { fs.delete(tempDeltaFile, true) } - logInfo("Canceled ") + logInfo("Aborted") } /** @@ -178,7 +182,8 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. */ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, "Cannot get iterator of store data before committing") + verify(state == COMMITTED, + "Cannot get iterator of store data before committing or after aborting") HDFSBackedStateStoreProvider.this.iterator(newVersion) } @@ -187,7 +192,8 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. 
*/ override def updates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, "Cannot get iterator of updates before committing") + verify(state == COMMITTED, + "Cannot get iterator of updates before committing or after aborting") allUpdates.values().asScala.toIterator } @@ -223,7 +229,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } /* Internal classes and methods */ @@ -277,7 +283,7 @@ private[state] class HDFSBackedStateStoreProvider( } else { if (!fs.isDirectory(baseDir)) { throw new IllegalStateException( - s"Cannot use ${id.checkpointLocation} for storing state data for $this as" + + s"Cannot use ${id.checkpointLocation} for storing state data for $this as " + s"$baseDir already exists and is not a directory") } } @@ -453,11 +459,11 @@ private[state] class HDFSBackedStateStoreProvider( filesForVersion(files, lastVersion).filter(_.isSnapshot == false) synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => - if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) { + if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { writeSnapshotFile(lastVersion, map) } case None => - // The last map is not loaded, probably some other instance is incharge + // The last map is not loaded, probably some other instance is in charge } } @@ -506,7 +512,6 @@ private[state] class HDFSBackedStateStoreProvider( .lastOption val deltaBatchFiles = latestSnapshotFileBeforeVersion match { case Some(snapshotFile) => - val deltaBatchIds = (snapshotFile.version + 1) to version val deltaFiles = allFiles.filter { file => file.version > snapshotFile.version && file.version <= version @@ -579,4 +584,3 @@ private[state] class HDFSBackedStateStoreProvider( } } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 07f63f928b8f..952150632519 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.streaming.state -import java.util.Timer import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable @@ -51,7 +50,7 @@ trait StateStore { def get(key: UnsafeRow): Option[UnsafeRow] /** Put a new value for a key. */ - def put(key: UnsafeRow, value: UnsafeRow) + def put(key: UnsafeRow, value: UnsafeRow): Unit /** * Remove keys that match the following condition. @@ -63,7 +62,7 @@ trait StateStore { */ def commit(): Long - /** Cancel all the updates that have been made to the store. */ + /** Abort all the updates that have been made to the store. */ def abort(): Unit /** @@ -109,8 +108,8 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate /** * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores * by their unique ids. In addition, when a SparkContext is active (i.e. SparkEnv.get is not null), - * it also runs a periodic background tasks to do maintenance on the loaded stores. For each - * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of + * it also runs a periodic background task to do maintenance on the loaded stores. 
For each + * store, it uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of * the store is the active instance. Accordingly, it either keeps it loaded and performs * maintenance, or unloads the store. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index f0f1f3a1a838..e55f63a6c8db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -26,7 +26,7 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex import SQLConf._ - val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + val minDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } @@ -34,4 +34,3 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex private[streaming] object StateStoreConf { val empty = new StateStoreConf() } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 812e1b0a3957..e418217238cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -50,8 +50,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { private val endpointName = "StateStoreCoordinator" /** - * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as - * executors. + * Create a reference to a [[StateStoreCoordinator]] */ def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { @@ -75,7 +74,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { } /** - * Reference to a [[StateStoreCoordinator]] that can be used to coordinator instances of + * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of * [[StateStore]]s across all the executors, and get their locations for job scheduling. */ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { @@ -142,5 +141,3 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadS context.reply(true) } } - - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index df3d82c113ca..d708486d8ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -22,12 +22,12 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.SerializableConfiguration /** * An RDD that allows computations to be executed against [[StateStore]]s. It - * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as - * preferred locations. 
+ * uses the [[StateStoreCoordinator]] to get the locations of loaded state stores + * and use that as the preferred locations. */ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index d3e823fdeb30..e96fb9f7550a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -55,6 +55,12 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } _content } + content ++= + UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) } } @@ -118,14 +124,12 @@ private[ui] abstract class ExecutionTable( {failedJobs} }} - {detailCell(executionUIData.physicalPlanDescription)} } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details.nonEmpty) { - + +details ++ - } - def toNodeSeq: Seq[Node] = {

{tableName}

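To make the StateStore update contract described in HDFSBackedStateStoreProvider's scaladoc above concrete, here is a minimal sketch of one update cycle. It assumes `store` has already been obtained via StateStore.get(...) (whose remaining arguments are elided in the scaladoc and stay elided here), and that `key` and `value` are UnsafeRows matching the store's key and value schemas; the `writeBatch` name is illustrative only, not part of this patch.

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore

def writeBatch(store: StateStore, key: UnsafeRow, value: UnsafeRow): Long = {
  store.put(key, value)              // stage an upsert while the store is still UPDATING
  // store.remove(pred)              // optionally drop keys matching a predicate
  val newVersion = store.commit()    // finalizes the delta file and returns the new version
  // after commit(): store.iterator() and store.updates() may be read;
  // on failure, store.abort() discards the staged updates instead
  newVersion
}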
@@ -197,7 +177,7 @@ private[ui] class RunningExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs", "Detail") + baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs") } private[ui] class CompletedExecutionTable( @@ -215,7 +195,7 @@ private[ui] class CompletedExecutionTable( showSucceededJobs = true, showFailedJobs = false) { - override protected def header: Seq[String] = baseHeader ++ Seq("Jobs", "Detail") + override protected def header: Seq[String] = baseHeader ++ Seq("Jobs") } private[ui] class FailedExecutionTable( @@ -234,5 +214,5 @@ private[ui] class FailedExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs", "Detail") + baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 9cb356f1ca37..7da8379c9aa9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -43,52 +43,65 @@ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression * * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird * - * @tparam I The input type for the aggregation. - * @tparam B The type of the intermediate value of the reduction. - * @tparam O The type of the final output result. + * @tparam IN The input type for the aggregation. + * @tparam BUF The type of the intermediate value of the reduction. + * @tparam OUT The type of the final output result. * @since 1.6.0 */ -abstract class Aggregator[-I, B, O] extends Serializable { +abstract class Aggregator[-IN, BUF, OUT] extends Serializable { /** * A zero value for this aggregation. Should satisfy the property that any b + zero = b. * @since 1.6.0 */ - def zero: B + def zero: BUF /** * Combine two values to produce a new value. For performance, the function may modify `b` and * return it instead of constructing new object for b. * @since 1.6.0 */ - def reduce(b: B, a: I): B + def reduce(b: BUF, a: IN): BUF /** * Merge two intermediate values. * @since 1.6.0 */ - def merge(b1: B, b2: B): B + def merge(b1: BUF, b2: BUF): BUF /** * Transform the output of the reduction. * @since 1.6.0 */ - def finish(reduction: B): O + def finish(reduction: BUF): OUT /** - * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * Specifies the [[Encoder]] for the intermediate value type. + * @since 2.0.0 + */ + def bufferEncoder: Encoder[BUF] + + /** + * Specifies the [[Encoder]] for the final ouput value type. + * @since 2.0.0 + */ + def outputEncoder: Encoder[OUT] + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]]. * operations. 
* @since 1.6.0 */ - def toColumn( - implicit bEncoder: Encoder[B], - cEncoder: Encoder[O]): TypedColumn[I, O] = { + def toColumn: TypedColumn[IN, OUT] = { + implicit val bEncoder = bufferEncoder + implicit val cEncoder = outputEncoder + val expr = AggregateExpression( TypedAggregateExpression(this), Complete, isDistinct = false) - new TypedColumn[I, O](expr, encoderFor[O]) + new TypedColumn[IN, OUT](expr, encoderFor[OUT]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5bc0034cb029..223122300dbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -154,6 +154,8 @@ object functions { /** * Aggregate function: returns the approximate number of distinct items in a group. * + * @param rsd maximum estimation error allowed (default = 0.05) + * * @group agg_funcs * @since 1.3.0 */ @@ -164,6 +166,8 @@ object functions { /** * Aggregate function: returns the approximate number of distinct items in a group. * + * @param rsd maximum estimation error allowed (default = 0.05) + * * @group agg_funcs * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc6ba1bcfb6d..2f9d63c2e813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. 
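As a concrete illustration of the reworked Aggregator contract above (encoders supplied via bufferEncoder/outputEncoder and a parameterless toColumn), a minimal sum aggregator could look like the following sketch; the LongSum object and the usage comment are illustrative, not part of this patch.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

object LongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L                                       // identity for merge/reduce
  def reduce(b: Long, a: Long): Long = b + a                // fold one input into the buffer
  def merge(b1: Long, b2: Long): Long = b1 + b2             // combine partial buffers
  def finish(reduction: Long): Long = reduction             // final output
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong     // buffer encoder now lives on the Aggregator
  def outputEncoder: Encoder[Long] = Encoders.scalaLong     // output encoder too
}

// usage: ds.select(LongSum.toColumn) for some ds: Dataset[Long]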
@@ -145,12 +144,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val USE_FILE_SCAN = SQLConfigBuilder("spark.sql.sources.fileScan") - .internal() - .doc("Use the new FileScanRDD path for reading HDSF based data sources.") - .booleanConf - .createWithDefault(true) - val PARQUET_SCHEMA_MERGING_ENABLED = SQLConfigBuilder("spark.sql.parquet.mergeSchema") .doc("When true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -193,7 +186,7 @@ object SQLConf { .stringConf .transform(_.toLowerCase()) .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) - .createWithDefault("gzip") + .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.parquet.filterPushdown") .doc("Enables Parquet filter push-down optimization when set to true.") @@ -481,8 +474,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def useCompression: Boolean = getConf(COMPRESS_CACHED) - def useFileScan: Boolean = getConf(USE_FILE_SCAN) - def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6acb41dd1f7d..4b9bf8daae37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -129,8 +129,17 @@ trait SchemaRelationProvider { * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ trait StreamSourceProvider { + + /** Returns the name and schema of the source that can be used to continually read data. */ + def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) + def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source @@ -458,16 +467,6 @@ trait FileFormat { options: Map[String, String], dataSchema: StructType): OutputWriterFactory - def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] - /** * Returns whether this format support returning columnar batch or not. 
* @@ -594,10 +593,7 @@ class HDFSFileCatalog( } if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) + val predicate = partitionPruningPredicates.reduce(expressions.And) val boundPredicate = InterpretedPredicate.create(predicate.transform { case a: AttributeReference => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala index bf78be9d9ff8..ba1facf11b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -37,7 +37,7 @@ abstract class ContinuousQueryListener { * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please * don't block this method as it will block your query. */ - def onQueryStarted(queryStarted: QueryStarted) + def onQueryStarted(queryStarted: QueryStarted): Unit /** * Called when there is some status update (ingestion rate updated, etc.) @@ -47,10 +47,10 @@ abstract class ContinuousQueryListener { * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] * is terminated when you are processing [[QueryProgress]]. */ - def onQueryProgress(queryProgress: QueryProgress) + def onQueryProgress(queryProgress: QueryProgress): Unit /** Called when a query is stopped, with or without error */ - def onQueryTerminated(queryTerminated: QueryTerminated) + def onQueryTerminated(queryTerminated: QueryTerminated): Unit } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f26c57b301c3..5abd62cbc245 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -454,6 +454,16 @@ public void testJavaEncoder() { Assert.assertEquals(data, ds.collectAsList()); } + @Test + public void testRandomSplit() { + List data = Arrays.asList("hello", "world", "from", "spark"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + double[] arraySplit = {1, 2, 3}; + + List> randomSplit = ds.randomSplitAsList(arraySplit, 1); + Assert.assertEquals("wrong number of splits", randomSplit.size(), 3); + } + /** * For testing error messages when creating an encoder on a private class. This is done * here since we cannot create truly private classes in Scala. 
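Since ContinuousQueryListener's callbacks above now carry explicit Unit return types, a concrete listener would be written as in the sketch below. This is illustrative only, and it assumes the QueryStarted/QueryProgress/QueryTerminated event classes are importable from the listener's companion object, as the scaladoc above suggests.

import org.apache.spark.sql.util.ContinuousQueryListener
import org.apache.spark.sql.util.ContinuousQueryListener.{QueryProgress, QueryStarted, QueryTerminated}

class PrintingQueryListener extends ContinuousQueryListener {
  override def onQueryStarted(queryStarted: QueryStarted): Unit =
    println("query started")        // keep this cheap: it runs on the query start path
  override def onQueryProgress(queryProgress: QueryProgress): Unit =
    println("status update")        // the query may already be terminated at this point
  override def onQueryTerminated(queryTerminated: QueryTerminated): Unit =
    println("query terminated")
}
// registration would go through the ContinuousQueryManager, e.g.
// sqlContext.streams.addListener(new PrintingQueryListener)  // assumed registration API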
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index 8cb174b9067f..0e49f871de5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -26,6 +26,7 @@ import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; @@ -39,12 +40,10 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { public void testTypedAggregationAnonClass() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); - Dataset> agged = - grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); + Dataset> agged = grouped.agg(new IntSumOf().toColumn()); Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + Dataset> agged2 = grouped.agg(new IntSumOf().toColumn()) .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); Assert.assertEquals( Arrays.asList( @@ -73,6 +72,16 @@ public Integer merge(Integer b1, Integer b2) { public Integer finish(Integer reduction) { return reduction; } + + @Override + public Encoder bufferEncoder() { + return Encoders.INT(); + } + + @Override + public Encoder outputEncoder() { + return Encoders.INT(); + } } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 86c640552236..e953a6e8ef0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1153,14 +1153,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => atFirstAgg = !atFirstAgg - } - case _ => { + case _ => if (atFirstAgg) { fail("Should not have operators between the two aggregations") } - } } } @@ -1170,12 +1168,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => if (atFirstAgg) { fail("Should not have back to back Aggregates") } atFirstAgg = true - } case e: ShuffleExchange => atFirstAgg = false case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 08b3389ad918..3a7215ee3972 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.language.postfixOps +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -26,74 +27,65 @@ 
import org.apache.spark.sql.test.SharedSQLContext object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { - override def zero: (Long, Long) = (0, 0) - override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { (countAndSum._1 + 1, countAndSum._2 + input._2) } - override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { (b1._1 + b2._1, b1._2 + b2._2) } - override def finish(reduction: (Long, Long)): (Long, Long) = reduction + override def bufferEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] + override def outputEncoder: Encoder[(Long, Long)] = Encoders.product[(Long, Long)] } + case class AggData(a: Int, b: String) + object ClassInputAgg extends Aggregator[AggData, Int, Int] { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: Int = 0 - - /** - * Combine two values to produce a new value. For performance, the function may modify `b` and - * return it instead of constructing new object for b. - */ override def reduce(b: Int, a: AggData): Int = b + a.a - - /** - * Transform the output of the reduction. - */ override def finish(reduction: Int): Int = reduction - - /** - * Merge two intermediate values - */ override def merge(b1: Int, b2: Int): Int = b1 + b2 + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt } + object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { - /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ override def zero: (Int, AggData) = 0 -> AggData(0, "0") - - /** - * Combine two values to produce a new value. For performance, the function may modify `b` and - * return it instead of constructing new object for b. - */ override def reduce(b: (Int, AggData), a: AggData): (Int, AggData) = (b._1 + 1, a) - - /** - * Transform the output of the reduction. 
- */ override def finish(reduction: (Int, AggData)): Int = reduction._1 - - /** - * Merge two intermediate values - */ override def merge(b1: (Int, AggData), b2: (Int, AggData)): (Int, AggData) = (b1._1 + b2._1, b1._2) + override def bufferEncoder: Encoder[(Int, AggData)] = Encoders.product[(Int, AggData)] + override def outputEncoder: Encoder[Int] = Encoders.scalaInt } + object NameAgg extends Aggregator[AggData, String, String] { def zero: String = "" - def reduce(b: String, a: AggData): String = a.b + b - def merge(b1: String, b2: String): String = b1 + b2 - def finish(r: String): String = r + override def bufferEncoder: Encoder[String] = Encoders.STRING + override def outputEncoder: Encoder[String] = Encoders.STRING +} + + +class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) + extends Aggregator[IN, OUT, OUT] { + + private val numeric = implicitly[Numeric[OUT]] + override def zero: OUT = numeric.zero + override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) + override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) + override def finish(reduction: OUT): OUT = reduction + override def bufferEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] + override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -187,6 +179,19 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)) } + test("generic typed sum") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum[(String, Int), Double](_._2.toDouble).toColumn), + ("a", 4.0), ("b", 3.0)) + + checkDataset( + ds.groupByKey(_._1) + .agg(new ParameterizedTypeSum((x: (String, Int)) => x._2.toInt).toColumn), + ("a", 4), ("b", 3)) + } + test("SPARK-12555 - result should not be corrupted after input columns are reordered") { val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e8e801084ffa..d074535bf626 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -620,6 +620,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val df = streaming.join(static, Seq("b")) assert(df.isStreaming, "streaming Dataset returned false for 'isStreaming'.") } + + test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { + val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + // Make sure the generated code for this plan can compile and execute. 
+ checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) + } } case class OtherTuple(_1: String, _2: Int) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dd648cdb8139..cdd404d699a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -89,6 +89,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Function: abcadf not found.") } + test("SPARK-14415: All functions should have own descriptions") { + for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + checkExistence(sql(s"describe function `$f`"), false, "To be added.") + } + } + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), @@ -1819,12 +1827,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e1 = intercept[AnalysisException] { sql("select * from in_valid_table") } - assert(e1.message.contains("Table not found")) + assert(e1.message.contains("Table or View not found")) val e2 = intercept[AnalysisException] { sql("select * from no_db.no_table").show() } - assert(e2.message.contains("Table not found")) + assert(e2.message.contains("Table or View not found")) val e3 = intercept[AnalysisException] { sql("select * from json.invalid_file") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 5dbf61987635..352fd07d0e8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap @@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X - Join w long codegen=true 275 / 352 76.2 13.1 19.4X + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X */ runBenchmark("Join w long duplicated", N) { @@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X - Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X + Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X */ val dim2 = broadcast(sqlContext.range(M) @@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) 
Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X - Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X + Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X */ val dim3 = broadcast(sqlContext.range(M) @@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X - outer join w long codegen=true 216 / 226 97.2 10.3 26.3X + outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + outer join w long codegen=true 261 / 276 80.5 12.4 11.7X */ runBenchmark("semi join w long", N) { @@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X - semi join w long codegen=true 211 / 229 99.2 10.1 22.2X + semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + semi join w long codegen=true 237 / 244 88.3 11.3 8.1X */ } @@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X - shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X + shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X */ } @@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 10 << 20 + val N = 20 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("hash") { iter => + benchmark.addCase("UnsafeRowhash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 var s = 0 while (i < N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashLong(i % 1000, 42) + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) s += h i += 1 } @@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + 
new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + } + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < N) { + val numKeys = 65536 + while (i < numKeys) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (loc.isDefined) { - value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - value.setInt(0, value.getInt(0) + 1) - i += 1 - } else { + if (!loc.isDefined) { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 } } } @@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 112 / 116 93.2 10.7 1.0X - fast hash 65 / 69 160.9 6.2 1.7X - arrayEqual 66 / 69 159.1 6.3 1.7X - Java HashMap (Long) 137 / 182 76.3 13.1 0.8X - Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X - Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X - BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X - BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X - Aggregate HashMap 56 / 62 187.9 5.3 2.0X - */ + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 9680f3a008a5..17f2343cf971 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) - val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + val mode2 = 
HashedRelationBroadcastMode(Literal(1L) :: Nil) + val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 4dc7d3461c9f..c1555114e8b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Properties + import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal @@ -71,6 +73,7 @@ class UnsafeFixedWidthAggregationMapSuite taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, + localProperties = new Properties, metricsSystem = null)) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 476d93fc2a9e..03d4be8ee528 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Properties + import scala.util.Random import org.apache.spark._ @@ -117,6 +119,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, + localProperties = new Properties, metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1f3779373b5d..01687877eeed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} +import java.util.Properties import org.apache.spark._ import org.apache.spark.memory.TaskMemoryManager @@ -113,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) val taskContext = new TaskContextImpl( - 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.createAll(sc)) val sorter = new ExternalSorter[Int, 
UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8e63b6987650..d6ccaf93488e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,17 +17,25 @@ package org.apache.spark.sql.execution.command +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ +// TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { private val parser = SparkSqlParser + private def assertUnsupported(sql: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + test("create database") { val sql = """ @@ -206,10 +214,12 @@ class DDLCommandSuite extends PlanTest { val parsed_view = parser.parsePlan(sql_view) val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None)) + TableIdentifier("new_table_name", None), + isView = false) val expected_view = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None)) + TableIdentifier("new_table_name", None), + isView = true) comparePlans(parsed_table, expected_table) comparePlans(parsed_view, expected_view) } @@ -236,14 +246,14 @@ class DDLCommandSuite extends PlanTest { val tableIdent = TableIdentifier("table_name", None) val expected1_table = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment")) + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) val expected2_table = AlterTableUnsetProperties( - tableIdent, Seq("comment", "test"), ifExists = false) + tableIdent, Seq("comment", "test"), ifExists = false, isView = false) val expected3_table = AlterTableUnsetProperties( - tableIdent, Seq("comment", "test"), ifExists = true) - val expected1_view = expected1_table - val expected2_view = expected2_table - val expected3_view = expected3_table + tableIdent, Seq("comment", "test"), ifExists = true, isView = false) + val expected1_view = expected1_table.copy(isView = true) + val expected2_view = expected2_table.copy(isView = true) + val expected3_view = expected3_table.copy(isView = true) comparePlans(parsed1_table, expected1_table) comparePlans(parsed2_table, expected2_table) @@ -327,11 +337,11 @@ class DDLCommandSuite extends PlanTest { Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), - ifNotExists = true)(sql1) + ifNotExists = true) val expected2 = AlterTableAddPartition( TableIdentifier("table_name", None), Seq((Map("dt" -> "2008-08-08"), Some("loc"))), - ifNotExists = false)(sql2) + ifNotExists = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -370,22 +380,16 @@ class DDLCommandSuite extends PlanTest { val expected = 
AlterTableRenamePartition( TableIdentifier("table_name", None), Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2008-09-09", "country" -> "uk"))(sql) + Map("dt" -> "2008-09-09", "country" -> "uk")) comparePlans(parsed, expected) } - test("alter table: exchange partition") { - val sql = + test("alter table: exchange partition (not supported)") { + assertUnsupported( """ |ALTER TABLE table_name_1 EXCHANGE PARTITION |(dt='2008-08-08', country='us') WITH TABLE table_name_2 - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableExchangePartition( - TableIdentifier("table_name_1", None), - TableIdentifier("table_name_2", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + """.stripMargin) } // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE] @@ -406,7 +410,10 @@ class DDLCommandSuite extends PlanTest { val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") val parsed1_table = parser.parsePlan(sql1_table) - val parsed2_table = parser.parsePlan(sql2_table) + val e = intercept[ParseException] { + parser.parsePlan(sql2_table) + } + assert(e.getMessage.contains("Operation not allowed")) intercept[ParseException] { parser.parsePlan(sql1_view) @@ -421,70 +428,39 @@ class DDLCommandSuite extends PlanTest { Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = true, - purge = false)(sql1_table) - val expected2_table = AlterTableDropPartition( - tableIdent, - Seq( - Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = false, - purge = true)(sql2_table) + ifExists = true) comparePlans(parsed1_table, expected1_table) - comparePlans(parsed2_table, expected2_table) } - test("alter table: archive partition") { - val sql = "ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableArchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: archive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')") } - test("alter table: unarchive partition") { - val sql = "ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableUnarchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: unarchive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')") } test("alter table: set file format") { - val sql1 = - """ - |ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' - |OUTPUTFORMAT 'test' SERDE 'test' INPUTDRIVER 'test' OUTPUTDRIVER 'test' - """.stripMargin - val sql2 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + + val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + "OUTPUTFORMAT 'test' SERDE 'test'" - val sql3 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + + val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET FILEFORMAT PARQUET" val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) val tableIdent = 
TableIdentifier("table_name", None) val expected1 = AlterTableSetFileFormat( tableIdent, None, - List("test", "test", "test", "test", "test"), + List("test", "test", "test"), None)(sql1) val expected2 = AlterTableSetFileFormat( - tableIdent, - None, - List("test", "test", "test"), - None)(sql2) - val expected3 = AlterTableSetFileFormat( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), Seq(), - Some("PARQUET"))(sql3) + Some("PARQUET"))(sql2) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) } test("alter table: set location") { @@ -506,55 +482,24 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - test("alter table: touch") { - val sql1 = "ALTER TABLE table_name TOUCH" - val sql2 = "ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableTouch( - tableIdent, - None)(sql1) - val expected2 = AlterTableTouch( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: touch (not supported)") { + assertUnsupported("ALTER TABLE table_name TOUCH") + assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')") } - test("alter table: compact") { - val sql1 = "ALTER TABLE table_name COMPACT 'compaction_type'" - val sql2 = + test("alter table: compact (not supported)") { + assertUnsupported("ALTER TABLE table_name COMPACT 'compaction_type'") + assertUnsupported( """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |COMPACT 'MAJOR' - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableCompact( - tableIdent, - None, - "compaction_type")(sql1) - val expected2 = AlterTableCompact( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")), - "MAJOR")(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |COMPACT 'MAJOR' + """.stripMargin) } - test("alter table: concatenate") { - val sql1 = "ALTER TABLE table_name CONCATENATE" - val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableMerge(tableIdent, None)(sql1) - val expected2 = AlterTableMerge( - tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: concatenate (not supported)") { + assertUnsupported("ALTER TABLE table_name CONCATENATE") + assertUnsupported( + "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE") } test("alter table: change column name/type/position/comment") { @@ -665,9 +610,12 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - test("commands only available in HiveContext") { + test("unsupported operations") { intercept[ParseException] { - parser.parsePlan("DROP TABLE D1.T1") + parser.parsePlan("DROP TABLE tab PURGE") + } + intercept[ParseException] { + parser.parsePlan("DROP TABLE tab FOR REPLICATION('eventid')") } intercept[ParseException] 
{ parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") @@ -682,6 +630,14 @@ class DDLCommandSuite extends PlanTest { |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") """.stripMargin) } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS (from '1', to '10') + """.stripMargin) + } intercept[ParseException] { parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") } @@ -692,4 +648,52 @@ class DDLCommandSuite extends PlanTest { val parsed = parser.parsePlan(sql) assert(parsed.isInstanceOf[Project]) } + + test("drop table") { + val tableName1 = "db.tab" + val tableName2 = "tab" + + val parsed1 = parser.parsePlan(s"DROP TABLE $tableName1") + val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1") + val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2") + val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2") + + val expected1 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) + val expected2 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) + val expected3 = + DropTable(TableIdentifier("tab", None), ifExists = false, isView = false) + val expected4 = + DropTable(TableIdentifier("tab", None), ifExists = true, isView = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + + test("drop view") { + val viewName1 = "db.view" + val viewName2 = "view" + + val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") + val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") + val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") + val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") + + val expected1 = + DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true) + val expected2 = + DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true) + val expected3 = + DropTable(TableIdentifier("view", None), ifExists = false, isView = true) + val expected4 = + DropTable(TableIdentifier("view", None), ifExists = true, isView = true) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 7084665b3b80..9ffffa0bdd6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -58,6 +58,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e.getMessage.toLowerCase.contains("operation not allowed")) } + private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { + if (expectException) intercept[AnalysisException] { body } else body + } + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) } @@ -320,7 +324,61 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext + 
test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: add partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") + } + + test("alter table: rename partition") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3)) + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')") + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100"), Map("b" -> "200"), part3)) + // rename without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10"), Map("b" -> "200"), part3)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") + } + // partition to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 PARTITION (x='300') RENAME TO PARTITION (x='333')") + } + } test("show tables") { withTempTable("show1a", "show2b") { @@ -391,6 +449,64 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Nil) } + test("drop table - temporary table") { + val catalog = sqlContext.sessionState.catalog + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + sql("DROP TABLE tab1") + assert(catalog.listTables("default") == Nil) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + + private def testDropTable(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listTables("dbx") == Seq(tableIdent)) + sql("DROP TABLE dbx.tab1") + assert(catalog.listTables("dbx") == Nil) + sql("DROP TABLE IF EXISTS dbx.tab1") + // no exception will be thrown + sql("DROP TABLE dbx.tab1") + } + + test("drop view in SQLContext") { + // SQLContext does not support create view. 
Log an error message, if tab1 does not exists + sql("DROP VIEW tab1") + + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.listTables("dbx") == Seq(tableIdent)) + + val e = intercept[AnalysisException] { + sql("DROP VIEW dbx.tab1") + } + assert( + e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) + } + private def convertToDatasourceTable( catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { @@ -429,10 +545,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(storageFormat.locationUri === Some(expected)) } } - // Optionally expect AnalysisException - def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { - if (expectException) intercept[AnalysisException] { body } else body - } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") verifyLocation("/path/to/your/lovely/heart") @@ -506,4 +618,102 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Some("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + } + // add partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')") + } + // partition to add already exists + intercept[AnalysisException] { + sql("ALTER TABLE tab1 ADD PARTITION (d='4')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + } + + private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + 
createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + createTablePartition(catalog, part4, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + } + // drop partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')") + } + // partition to drop does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 DROP PARTITION (x='300')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 41f536fc372a..dac56d393647 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -22,8 +22,6 @@ import java.io.File import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.Job -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} @@ -34,8 +32,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ @@ -196,6 +193,34 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) } + test("partitioned table - case insensitive") { + withSQLConf("spark.sql.caseSensitive" -> "false") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("P1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. 
+ checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + } + test("partitioned table - after scan filters") { val table = createTable( @@ -365,18 +390,6 @@ class TestFileFormat extends FileFormat { throw new NotImplementedError("JUST FOR TESTING") } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - throw new NotImplementedError("JUST FOR TESTING") - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2a18acb95b5e..e17340c70b7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1664,4 +1664,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("wide nested json table") { + val nested = (1 to 100).map { i => + s""" + |"c$i": $i + """.stripMargin + }.mkString(", ") + val json = s""" + |{"a": [{$nested}], "b": [{$nested}]} + """.stripMargin + val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) + val df = sqlContext.read.json(rdd) + assert(df.schema.size === 2) + df.collect() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed87a9943952..371a9ed617d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { + val mm = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - 
test("LongArrayRelation") { + test("LongToUnsafeRowMap") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) - val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) - assert(longRelation.isInstanceOf[LongArrayRelation]) - val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] + val key = Seq(BoundReference(0, IntegerType, false)) + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) + assert(longRelation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) + val row = longRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) + assert(!longRelation2.keyIsUnique) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longArrayRelation.writeExternal(out) + longRelation2.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new LongArrayRelation() + val relation = new LongHashedRelation() relation.readExternal(in) + assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) - assert(row.getInt(0) === i) - assert(row.getInt(1) === i + 1) + val rows = relation.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8a551cd78ca5..31b63f2ce13d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -612,23 +612,20 @@ class ColumnarBatchSuite extends SparkFunSuite { val a2 = r2.getList(v._2).toArray assert(a1.length == a2.length, "Seed = " + seed) childType match { - case DoubleType => { + case DoubleType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), "Seed = " + seed) i += 1 } - } - case FloatType => { + case FloatType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), "Seed = " + seed) i += 1 } - } - case t: DecimalType => var i = 0 while (i < a1.length) { @@ -640,7 +637,6 @@ class ColumnarBatchSuite extends SparkFunSuite { } i += 1 } - case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 28c558208f6b..00efe21d39de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ @@ -31,22 +32,50 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils object LastOptions { + + var mockStreamSourceProvider = mock(classOf[StreamSourceProvider]) + var mockStreamSinkProvider = mock(classOf[StreamSinkProvider]) var parameters: Map[String, String] = null var schema: Option[StructType] = null var partitionColumns: Seq[String] = Nil + + def clear(): Unit = { + parameters = null + schema = null + partitionColumns = null + reset(mockStreamSourceProvider) + reset(mockStreamSinkProvider) + } } /** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + ("dummySource", fakeSchema) + } + override def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { LastOptions.parameters = parameters LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.createSource( + sqlContext, metadataPath, schema, providerName, parameters) new Source { - override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) @@ -64,6 +93,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { partitionColumns: Seq[String]): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns + LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -117,7 +147,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") - LastOptions.parameters = null + LastOptions.clear() df.write .format("org.apache.spark.sql.streaming.test") @@ -181,7 +211,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("path") == "/test") - LastOptions.parameters = null + LastOptions.clear() df.write .format("org.apache.spark.sql.streaming.test") @@ -204,7 +234,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("boolOpt") == "false") assert(LastOptions.parameters("doubleOpt") == "6.7") - LastOptions.parameters = null + LastOptions.clear() df.write .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) @@ -303,4 +333,39 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) } + + test("source metadataPath") { + LastOptions.clear() + + val 
checkpointLocation = newMetadataDir + + val df1 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val df2 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val q = df1.union(df2).write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", checkpointLocation) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/0", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/1", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 09daa7f81a97..73d1b1b1d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -63,6 +63,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { format: String, path: String, schema: Option[StructType] = None): FileStreamSource = { + val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath val reader = if (schema.isDefined) { sqlContext.read.format(format).schema(schema.get) @@ -72,7 +73,8 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { reader.stream(path) .queryExecution.analyzed .collect { case StreamingRelation(dataSource, _, _) => - dataSource.createSource().asInstanceOf[FileStreamSource] + // There is only one source in our tests so just set sourceId to 0 + dataSource.createSource(s"$checkpointLocation/sources/0").asInstanceOf[FileStreamSource] }.head } @@ -98,9 +100,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } df.queryExecution.analyzed .collect { case StreamingRelation(dataSource, _, _) => - dataSource.createSource().asInstanceOf[FileStreamSource] - }.head - .schema + dataSource.sourceSchema() + }.head._2 } test("FileStreamSource schema: no path") { @@ -340,7 +341,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { Utils.deleteRecursively(src) Utils.deleteRecursively(tmp) } - } class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 5249aa28dd6c..1f2834054519 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -59,7 +59,7 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { } test("error if attempting to resume specific checkpoint") { - val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath + val location = Utils.createTempDir(namePrefix = "steaming.checkpoint").getCanonicalPath val input = MemoryStream[Int] val query = input.toDF().write diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e4ea55552691..2bd27c7efdbd 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -115,8 +115,17 @@ class StreamSuite extends StreamTest with SharedSQLContext { */ class FakeDefaultSource extends StreamSourceProvider { + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) + override def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index e93b0c145fd6..eb49eabcb1ba 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -162,7 +162,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" @@ -172,7 +172,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "OK" + -> "" ) } @@ -187,7 +187,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "USE hive_test_db;" -> "", "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test" ) @@ -210,9 +210,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin - -> "OK", + -> "", "CREATE TABLE sourceTable (key INT, val STRING);" - -> "OK", + -> "", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" @@ -220,9 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" - -> "OK", + -> "", "DROP TABLE sourceTable;" - -> "OK" + -> "" ) } @@ -230,7 +230,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" - -> "Error in query: Table not found: nonexistent_table;" + -> "Error in query: Table or View not found: nonexistent_table;" ) } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9e3cb18d457c..989e68aebed9 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -366,17 +366,92 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "sort_merge_join_desc_6", "sort_merge_join_desc_7", + // These tests try to create a table with bucketed 
columns, which we don't support + "auto_join32", + "auto_join_filters", + "auto_smb_mapjoin_14", + "ct_case_insensitive", + "explain_rearrange", + "groupby_sort_10", + "groupby_sort_2", + "groupby_sort_3", + "groupby_sort_4", + "groupby_sort_5", + "groupby_sort_7", + "groupby_sort_8", + "groupby_sort_9", + "groupby_sort_test_1", + "inputddl4", + "join_filters", + "join_nulls", + "join_nullsafe", + "load_dyn_part2", + "orc_empty_files", + "reduce_deduplicate", + "smb_mapjoin9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + "smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_6", + "smb_mapjoin_7", + "smb_mapjoin_8", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + + // These tests try to create a table with skewed columns, which we don't support + "create_skewed_table1", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + + // This test tries to create a table like with TBLPROPERTIES clause, which we don't support. + "create_like_tbl_props", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", "alter_index", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", // Macro commands are not supported "macro", // Create partitioned view is not supported "create_like_view", - "describe_formatted_view_partitioned" + "describe_formatted_view_partitioned", + + // This uses CONCATENATE, which we don't support + "alter_merge_2", + + // TOUCH is not supported + "touch" ) /** @@ -392,7 +467,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter2", "alter3", "alter5", - "alter_merge_2", "alter_partition_format_loc", "alter_partition_with_whitelist", "alter_rename_partition", @@ -430,33 +504,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "auto_join3", "auto_join30", "auto_join31", - "auto_join32", "auto_join4", "auto_join5", "auto_join6", "auto_join7", "auto_join8", "auto_join9", - "auto_join_filters", "auto_join_nulls", "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", "binary_constant", "binarysortable_1", "cast1", @@ -485,15 +540,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "count", "cp_mj_rc", "create_insert_outputformat", - "create_like_tbl_props", "create_nested_type", - "create_skewed_table1", "create_struct_table", "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", - "ct_case_insensitive", "database_drop", "database_location", "database_properties", @@ -529,7 +581,6 
@@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_distributeby1", "escape_orderby1", "escape_sortby1", - "explain_rearrange", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", @@ -584,16 +635,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_neg_float", "groupby_ppd", "groupby_ppr", - "groupby_sort_10", - "groupby_sort_2", - "groupby_sort_3", - "groupby_sort_4", - "groupby_sort_5", "groupby_sort_6", - "groupby_sort_7", - "groupby_sort_8", - "groupby_sort_9", - "groupby_sort_test_1", "having", "implicit_cast1", "index_serde", @@ -648,7 +690,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl1", "inputddl2", "inputddl3", - "inputddl4", "inputddl6", "inputddl7", "inputddl8", @@ -704,11 +745,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_array", "join_casesensitive", "join_empty", - "join_filters", "join_hive_626", "join_map_ppr", - "join_nulls", - "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -732,7 +770,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part13", "load_dyn_part14", "load_dyn_part14_win", - "load_dyn_part2", "load_dyn_part3", "load_dyn_part4", "load_dyn_part5", @@ -785,7 +822,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nullscript", "optional_outer", "orc_dictionary_threshold", - "orc_empty_files", "order", "order2", "outer_join_ppr", @@ -841,7 +877,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "rcfile_null_value", "rcfile_toleratecorruptions", "rcfile_union", - "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", "reduce_deduplicate_extended", @@ -862,31 +897,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_functions", "show_partitions", "show_tblproperties", - "skewjoinopt13", - "skewjoinopt18", - "skewjoinopt9", - "smb_mapjoin9", - "smb_mapjoin_1", - "smb_mapjoin_10", - "smb_mapjoin_13", - "smb_mapjoin_14", - "smb_mapjoin_15", - "smb_mapjoin_16", - "smb_mapjoin_17", - "smb_mapjoin_2", - "smb_mapjoin_21", - "smb_mapjoin_25", - "smb_mapjoin_3", - "smb_mapjoin_4", - "smb_mapjoin_5", - "smb_mapjoin_6", - "smb_mapjoin_7", - "smb_mapjoin_8", "sort", - "sort_merge_join_desc_1", - "sort_merge_join_desc_2", - "sort_merge_join_desc_3", - "sort_merge_join_desc_4", "stats0", "stats_aggregator_error_1", "stats_empty_partition", @@ -897,7 +908,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "touch", "transform_ppr1", "transform_ppr2", "truncate_table", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b1156fb3e259..ff52e6ad74c3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -25,7 +25,6 @@ import org.apache.thrift.TException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchItemException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.hive.client.HiveClient @@ -66,8 +65,6 @@ private[spark] class 
HiveExternalCatalog(client: HiveClient) extends ExternalCat try { body } catch { - case e: NoSuchItemException => - throw new AnalysisException(e.getMessage) case NonFatal(e) if isClientException(e) => throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) } @@ -182,6 +179,10 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.getTable(db, table) } + override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { + client.getTableOption(db, table) + } + override def tableExists(db: String, table: String): Boolean = withClient { client.getTableOption(db, table).isDefined } @@ -196,6 +197,11 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.listTables(db, pattern) } + override def showCreateTable(db: String, table: String): String = withClient { + require(this.tableExists(db, table), s"The table $db.$table does not exist.") + client.showCreateTable(db, table) + } + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- @@ -215,26 +221,7 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat parts: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = withClient { requireTableExists(db, table) - // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we - // need to implement it here ourselves. This is currently somewhat expensive because - // we make multiple synchronous calls to Hive for each partition we want to drop. - val partsToDrop = - if (ignoreIfNotExists) { - parts.filter { spec => - try { - getPartition(db, table, spec) - true - } catch { - // Filter out the partitions that do not actually exist - case _: AnalysisException => false - } - } - } else { - parts - } - if (partsToDrop.nonEmpty) { - client.dropPartitions(db, table, partsToDrop) - } + client.dropPartitions(db, table, parts, ignoreIfNotExists) } override def renamePartitions( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 589862c7c02e..585befe37825 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -450,9 +450,7 @@ private[hive] trait HiveInspectors { if (o != null) { val array = o.asInstanceOf[ArrayData] val values = new java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => { - values.add(wrapper(e)) - }) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) values } else { null @@ -468,9 +466,8 @@ private[hive] trait HiveInspectors { if (o != null) { val map = o.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(keyWrapper(k), valueWrapper(v)) - }) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) jmap } else { null @@ -587,9 +584,9 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => list.add(wrap(e, x.getListElementObjectInspector, tpe)) - }) + ) list case x: MapObjectInspector => val keyType = 
dataType.asInstanceOf[MapType].keyType @@ -599,10 +596,10 @@ private[hive] trait HiveInspectors { // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { + map.foreach(keyType, valueType, (k, v) => hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), wrap(v, x.getMapValueObjectInspector, valueType)) - }) + ) hashMap } @@ -704,9 +701,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { - list.add(wrap(e, listObjectInspector, dt)) - }) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => + list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -718,9 +714,8 @@ private[hive] trait HiveInspectors { val map = value.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { - jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) - }) + map.foreach(keyType, valueType, (k, v) => + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 14f331961ef4..ccc8345d7375 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -91,7 +91,7 @@ private[hive] object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), "avro" -> HiveSerDe( @@ -905,8 +905,13 @@ private[hive] case class MetastoreRelation( val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(toHiveColumn).asJava) - tTable.setPartitionKeys(table.partitionColumns.map(toHiveColumn).asJava) + + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + sd.setCols(schema.asJava) + tTable.setPartitionKeys(partCols.asJava) table.storage.locationUri.foreach(sd.setLocation) table.storage.inputFormat.foreach(sd.setInputFormat) @@ -1013,7 +1018,10 @@ private[hive] case class MetastoreRelation( val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.schema.map(_.toAttribute) + // TODO: just make this hold the schema itself, not just non-partition columns + val attributes = table.schema + .filter { c => !table.partitionColumnNames.contains(c.name) } + .map(_.toAttribute) val output = attributes ++ partitionKeys diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0cccc22e5a62..f0544e281a38 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,22 +26,22 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, FunctionResourceLoader, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.Utils - private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, @@ -103,13 +103,13 @@ private[sql] class HiveSessionCatalog( } def createDataSourceTable( - name: TableIdentifier, - userSpecifiedSchema: Option[StructType], - partitionColumns: Array[String], - bucketSpec: Option[BucketSpec], - provider: String, - options: Map[String, String], - isExternal: Boolean): Unit = { + name: TableIdentifier, + userSpecifiedSchema: Option[StructType], + partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], + provider: String, + options: Map[String, String], + isExternal: Boolean): Unit = { metastoreCatalog.createDataSourceTable( name, userSpecifiedSchema, partitionColumns, bucketSpec, provider, options, isExternal) } @@ -155,7 +155,7 @@ private[sql] class HiveSessionCatalog( new HiveFunctionWrapper(clazz.getName), children, isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. + udaf.dataType // Force it to check input data types. udaf } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) @@ -242,7 +242,6 @@ private[sql] class HiveSessionCatalog( createTempFunction(functionName, info, builder, ignoreIfExists = false) } } - private[sql] object HiveSessionCatalog { // This is the list of Hive's built-in functions that are commonly used and we want to // pre-load when we create the FunctionRegistry. 
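A note on the schema handling that recurs in the Hive-side changes above and below: the patch keeps partition columns inside the unified Spark schema and splits them back out wherever Hive requires the two lists to be disjoint. A minimal standalone sketch of that split, using illustrative column and field names that are not taken from the patch:

  case class Column(name: String, dataType: String)

  // Hive keeps partition keys out of the storage descriptor's column list,
  // so a unified schema has to be partitioned by column name first.
  val schema = Seq(Column("key", "int"), Column("value", "string"), Column("dt", "string"))
  val partitionColumnNames = Set("dt")

  val (partCols, dataCols) = schema.partition(c => partitionColumnNames.contains(c.name))
  assert(partCols.map(_.name) == Seq("dt"))
  assert(dataCols.map(_.name) == Seq("key", "value"))

This mirrors the partition-by-name split used in MetastoreRelation and in HiveClientImpl.toHiveTable, just with simplified types.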
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index f44937ec6f98..010361a32eb3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _} -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, - DescribeCommand} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, CreateTempTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ private[hive] trait HiveStrategies { @@ -90,6 +89,11 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect if c.temporary => + val cmd = CreateTempTableUsingAsSelect( + c.tableIdent, c.provider, c.partitionColumns, c.mode, c.options, c.child) + ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect => val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, c.bucketSpec, c.mode, c.options, c.child) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 94794b1572ee..7280ebc6302f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -120,16 +120,13 @@ private[hive] trait HiveClient { ignoreIfExists: Boolean): Unit /** - * Drop one or many partitions in the given table. - * - * Note: Unfortunately, Hive does not currently provide a way to ignore this call if the - * partitions do not already exist. The seemingly relevant flag `ifExists` in - * [[org.apache.hadoop.hive.metastore.PartitionDropOptions]] is not read anywhere. + * Drop one or many partitions in the given table, assuming they exist. */ def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit /** * Rename one or many existing table partitions, assuming they exist. @@ -252,4 +249,5 @@ private[hive] trait HiveClient { /** Used for testing only. Removes all metadata from this instance of Hive. 
*/ def reset(): Unit + def showCreateTable(db: String, table: String): String } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index d0eb9ddf50ae..ed1ad41384c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -26,11 +26,13 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} +import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState @@ -43,7 +45,9 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{CircularBuffer, Utils} /** @@ -299,6 +303,10 @@ private[hive] class HiveClientImpl( tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up $dbName.$tableName") Option(client.getTable(dbName, tableName, false)).map { h => + // Note: Hive separates partition columns and the schema, but for us the + // partition columns are part of the schema + val partCols = h.getPartCols.asScala.map(fromHiveColumn) + val schema = h.getCols.asScala.map(fromHiveColumn) ++ partCols CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -307,9 +315,10 @@ private[hive] class HiveClientImpl( case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW }, - schema = h.getCols.asScala.map(fromHiveColumn), - partitionColumns = h.getPartCols.asScala.map(fromHiveColumn), - sortColumns = Seq(), + schema = schema, + partitionColumnNames = partCols.map(_.name), + sortColumnNames = Seq(), // TODO: populate this + bucketColumnNames = h.getBucketCols.asScala, numBuckets = h.getNumBuckets, createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, @@ -367,9 +376,25 @@ private[hive] class HiveClientImpl( override def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit = withHiveState { + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withHiveState { // TODO: figure out how to drop multiple partitions in one call - specs.foreach 
{ s => client.dropPartition(db, table, s.values.toList.asJava, true) } + val hiveTable = client.getTable(db, table, true /* throw exception */) + specs.foreach { s => + // The provided spec here can be a partial spec, i.e. it will match all partitions + // whose specs are supersets of this partial spec. E.g. If a table has partitions + // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both. + val matchingParts = client.getPartitions(hiveTable, s.asJava).asScala + if (matchingParts.isEmpty && !ignoreIfNotExists) { + throw new AnalysisException( + s"partition to drop '$s' does not exist in table '$table' database '$db'") + } + matchingParts.foreach { hivePartition => + val dropOptions = new PartitionDropOptions + dropOptions.ifExists = ignoreIfNotExists + client.dropPartition(db, table, hivePartition.getValues, dropOptions) + } + } } override def renamePartitions( @@ -604,11 +629,217 @@ private[hive] class HiveClientImpl( } } + override def showCreateTable(db: String, tableName: String): String = withHiveState { + Option(client.getTable(db, tableName, false)).map { hiveTable => + val tblProperties = hiveTable.getParameters.asScala.toMap + val duplicateProps = scala.collection.mutable.ArrayBuffer.empty[String] + + if (tblProperties.get("spark.sql.sources.provider").isDefined) { + generateDataSourceDDL(hiveTable) + } else { + generateHiveDDL(hiveTable) + } + }.get + } + /* -------------------------------------------------------- * | Helper methods for converting to and from Hive classes | * -------------------------------------------------------- */ + private def generateCreateTableHeader( + hiveTable: HiveTable, + processedProps: scala.collection.mutable.ArrayBuffer[String]): String = { + val sb = new StringBuilder("CREATE ") + if(hiveTable.isTemporary) { + sb.append("TEMPORARY ") + } + if (hiveTable.getTableType == HiveTableType.EXTERNAL_TABLE) { + processedProps += "EXTERNAL" + sb.append("EXTERNAL TABLE " + + quoteIdentifier(hiveTable.getDbName) + "." + quoteIdentifier(hiveTable.getTableName)) + } else { + sb.append("TABLE " + + quoteIdentifier(hiveTable.getDbName) + "." + quoteIdentifier(hiveTable.getTableName)) + } + sb.toString() + } + + private def generateColsDataSource( + hiveTable: HiveTable, + processedProps: scala.collection.mutable.ArrayBuffer[String]): String = { + val schemaStringFromParts: Option[String] = { + val props = hiveTable.getParameters.asScala + props.get("spark.sql.sources.schema.numParts").map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = props.get(s"spark.sql.sources.schema.part.$index").orNull + if (part == null) { + throw new AnalysisException( + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") + } + part + } + // Stick all parts back to a single schema string. + parts.mkString + } + } + + if (schemaStringFromParts.isDefined) { + (schemaStringFromParts.map(s => DataType.fromJson(s).asInstanceOf[StructType]). 
+ get map { f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}" }) + .mkString("( ", ", ", " )") + } else { + "" + } + } + + private def generateDataSourceDDL(hiveTable: HiveTable): String = { + val processedProperties = scala.collection.mutable.ArrayBuffer.empty[String] + val sb = new StringBuilder(generateCreateTableHeader(hiveTable, processedProperties)) + // It is possible that the column list returned from hive metastore is just a dummy + // one, such as "col array", because the metastore was created as spark sql + // specific metastore (refer to HiveMetaStoreCatalog.createDataSourceTable. + // newSparkSQLSpecificMetastoreTable). In such case, the column schema information + // is located in tblproperties in json format. + sb.append(generateColsDataSource(hiveTable, processedProperties) + "\n") + sb.append("USING " + hiveTable.getProperty("spark.sql.sources.provider") + "\n") + val options = scala.collection.mutable.ArrayBuffer.empty[String] + hiveTable.getSd.getSerdeInfo.getParameters.asScala.foreach { e => + options += "" + escapeHiveCommand(e._1) + " '" + escapeHiveCommand(e._2) + "'" + } + if (options.size > 0) { + sb.append("OPTIONS " + options.mkString("( ", ", \n", " )")) + } + // create table using syntax does not support EXTERNAL keyword + sb.toString.replace("EXTERNAL TABLE", "TABLE") + } + + private def generateHiveDDL(hiveTable: HiveTable): String = { + val sb = new StringBuilder + val tblProperties = hiveTable.getParameters.asScala.toMap + val duplicateProps = scala.collection.mutable.ArrayBuffer.empty[String] + + if (hiveTable.getTableType == HiveTableType.VIRTUAL_VIEW) { + sb.append("CREATE VIEW " + quoteIdentifier(hiveTable.getDbName) + "." + + quoteIdentifier(hiveTable.getTableName) + " AS " + hiveTable.getViewOriginalText) + } else { + // create table header + sb.append(generateCreateTableHeader(hiveTable, duplicateProps)) + + // column list + sb.append( + hiveTable.getCols.asScala.map(fromHiveColumn).map { col => + quoteIdentifier(col.name) + " " + col.dataType + ( col.comment.getOrElse("") match { + case cmt: String if cmt.length > 0 => + " COMMENT '" + escapeHiveCommand(cmt) + "'" + case _ => "" + }) + }.mkString("( ", ", ", " ) \n")) + + // partition + val partCols = hiveTable.getPartitionKeys + if (partCols != null && partCols.size() > 0) { + sb.append("PARTITIONED BY ") + sb.append( + partCols.asScala.map(fromHiveColumn).map { col => + quoteIdentifier(col.name) + " " + col.dataType + (col.comment.getOrElse("") match { + case cmt: String if cmt.length > 0 => + " COMMENT '" + escapeHiveCommand(cmt) + "'" + case _ => "" + }) + }.mkString("( ", ", ", " )\n")) + } + + // sort bucket + val bucketCols = hiveTable.getBucketCols + if (bucketCols != null && bucketCols.size() > 0) { + sb.append("CLUSTERED BY ") + sb.append(bucketCols.asScala.map(quoteIdentifier(_)).mkString("( ", ", ", " ) \n")) + // SORTing columns + val sortCols = hiveTable.getSortCols + if (sortCols != null && sortCols.size() > 0) { + sb.append("SORTED BY ") + sb.append( + sortCols.asScala.map { col => + quoteIdentifier(col.getCol) + " " + (col.getOrder match { + case o if o == BaseSemanticAnalyzer.HIVE_COLUMN_ORDER_DESC => "DESC" + case _ => "ASC" + }) + }.mkString("( ", ", ", " ) \n")) + } + if (hiveTable.getNumBuckets > 0) { + sb.append("INTO " + hiveTable.getNumBuckets + " BUCKETS \n") + } + } + + // skew spec + val skewCols = hiveTable.getSkewedColNames + if (skewCols != null && skewCols.size() > 0) { + sb.append("SKEWED BY ") + sb.append(skewCols.asScala.map(quoteIdentifier(_)).mkString("( ", ", 
", " ) \n")) + val skewColValues = hiveTable.getSkewedColValues + sb.append("ON ") + sb.append(skewColValues.asScala.map { values => + values.asScala.map("" + _).mkString("(", ", ", ")") + }.mkString("(", ", ", ")\n")) + if (hiveTable.isStoredAsSubDirectories) { + sb.append("STORED AS DIRECTORIES\n") + } + } + + // ROW FORMAT + val storageHandler = hiveTable.getStorageHandler + val serdeProps = hiveTable.getSd.getSerdeInfo.getParameters.asScala + sb.append("ROW FORMAT SERDE '" + escapeHiveCommand(hiveTable.getSerializationLib) + "'\n") + if (storageHandler == null) { + sb.append("WITH SERDEPROPERTIES ") + sb.append(serdeProps.map { serdeProp => + "'" + escapeHiveCommand(serdeProp._1) + "'='" + escapeHiveCommand(serdeProp._2) + "'" + }.mkString("( ", ", ", " )\n")) + + sb.append("STORED AS INPUTFORMAT '" + + escapeHiveCommand(hiveTable.getInputFormatClass.getName) + "' \n") + + sb.append("OUTPUTFORMAT '" + + escapeHiveCommand(hiveTable.getOutputFormatClass.getName) + "' \n") + } else { + // storage handler case + duplicateProps += hive_metastoreConstants.META_TABLE_STORAGE + sb.append("STORED BY '" + escapeHiveCommand( + tblProperties.getOrElse(hive_metastoreConstants.META_TABLE_STORAGE, "")) + "'\n") + sb.append("WITH SERDEPROPERTIES ") + sb.append(serdeProps.map { serdeProp => + "'" + escapeHiveCommand(serdeProp._1) + "'='" + escapeHiveCommand(serdeProp._2) + "'" + }.mkString("( ", ", ", " )\n")) + } + + // table location + sb.append("LOCATION '" + + escapeHiveCommand(shim.getDataLocation(hiveTable).get) + "' \n") + + // table properties + val propertPairs = hiveTable.getParameters.asScala.collect { + case (k, v) if !duplicateProps.contains(k) => + "'" + escapeHiveCommand(k) + "'='" + escapeHiveCommand(v) + "'" + } + if (propertPairs.size>0) { + sb.append("TBLPROPERTIES " + propertPairs.mkString("( ", ", \n", " )") + "\n") + } + } + sb.toString() + } + + private def escapeHiveCommand(str: String): String = { + str.map{c => + if (c == '\'' || c == ';') { + '\\' + } else { + c + } + } + } + private def toInputFormat(name: String) = Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] @@ -659,15 +890,37 @@ private[hive] class HiveClientImpl( private def toHiveTable(table: CatalogTable): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) + // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. + // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. + // (metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) hiveTable.setTableType(table.tableType match { - case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE - case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE + case CatalogTableType.EXTERNAL_TABLE => + hiveTable.setProperty("EXTERNAL", "TRUE") + HiveTableType.EXTERNAL_TABLE + case CatalogTableType.MANAGED_TABLE => + HiveTableType.MANAGED_TABLE case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW }) - hiveTable.setFields(table.schema.map(toHiveColumn).asJava) - hiveTable.setPartCols(table.partitionColumns.map(toHiveColumn).asJava) + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + if (table.schema.isEmpty) { + // This is a hack to preserve existing behavior. 
Before Spark 2.0, we do not + // set a default serde here (this was done in Hive), and so if the user provides + // an empty schema Hive would automatically populate the schema with a single + // field "col". However, after SPARK-14388, we set the default serde to + // LazySimpleSerde so this implicit behavior no longer happens. Therefore, + // we need to do it in Spark ourselves. + hiveTable.setFields( + Seq(new FieldSchema("col", "array", "from deserializer")).asJava) + } else { + hiveTable.setFields(schema.asJava) + } + hiveTable.setPartCols(partCols.asJava) // TODO: set sort columns here too + hiveTable.setBucketCols(table.bucketColumnNames.asJava) hiveTable.setOwner(conf.getUser) hiveTable.setNumBuckets(table.numBuckets) hiveTable.setCreateTime((table.createTime / 1000).toInt) @@ -675,9 +928,11 @@ private[hive] class HiveClientImpl( table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) } table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) - table.storage.serde.foreach(hiveTable.setSerializationLib) + hiveTable.setSerializationLib( + table.storage.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) } table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) } + table.comment.foreach { c => hiveTable.setProperty("comment", c) } table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) } table.viewText.foreach { t => hiveTable.setViewExpandedText(t) } hiveTable @@ -708,5 +963,4 @@ private[hive] class HiveClientImpl( serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } - } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index ab69d3502e93..a97b65e27bc5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -26,15 +26,14 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder -import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} -import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveSerDe} +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} +import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** @@ -103,19 +102,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } - /** - * Create a [[DropTable]] command. 
- */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.PURGE != null) { - logWarning("PURGE option is ignored.") - } - if (ctx.REPLICATION != null) { - logWarning("REPLICATION clause is ignored.") - } - DropTable(visitTableIdentifier(ctx.tableIdentifier).toString, ctx.EXISTS != null) - } - /** * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other * options are passed on to Hive) e.g.: @@ -134,85 +120,128 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Create a [[CatalogStorageFormat]]. This is part of the [[CreateTableAsSelect]] command. + * Create a [[CatalogStorageFormat]] for creating tables. */ override def visitCreateFileFormat( ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - if (ctx.storageHandler == null) { - typedVisit[CatalogStorageFormat](ctx.fileFormat) - } else { - visitStorageHandler(ctx.storageHandler) + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + visitTableFileFormat(c) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + visitGenericFileFormat(c) + case (null, storageHandler) => + throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx) + case _ => + throw new ParseException("expected either STORED AS or STORED BY, not both", ctx) } } /** - * Create a [[CreateTableAsSelect]] command. + * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelect]]. + * + * This is not used to create datasource tables, which is handled through + * "CREATE TABLE ... USING ...". + * + * Note: several features are currently not supported - temporary tables, bucketing, + * skewed columns and storage handlers (STORED BY). + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} */ - override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = { - if (ctx.query == null) { - HiveNativeCommand(command(ctx)) + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + // TODO: implement temporary tables + if (temp) { + throw new ParseException( + "CREATE TEMPORARY TABLE is not supported yet. " + + "Please use registerTempTable as an alternative.", ctx) + } + if (ctx.skewSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx) + } + if (ctx.bucketSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx) + } + val tableType = if (external) { + CatalogTableType.EXTERNAL_TABLE } else { - // Get the table header. 
- val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val tableType = if (external) { - CatalogTableType.EXTERNAL_TABLE - } else { - CatalogTableType.MANAGED_TABLE - } - - // Unsupported clauses. - if (temp) { - logWarning("TEMPORARY clause is ignored.") - } - if (ctx.bucketSpec != null) { - // TODO add this - we need cluster columns in the CatalogTable for this to work. - logWarning("CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause is ignored.") - } - if (ctx.skewSpec != null) { - logWarning("SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause is ignored.") - } - - // Create the schema. - val schema = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns(_, _.toLowerCase)) - - // Get the column by which the table is partitioned. - val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns(_)) - - // Create the storage. - def format(fmt: ParserRuleContext): CatalogStorageFormat = { - Option(fmt).map(typedVisit[CatalogStorageFormat]).getOrElse(EmptyStorageFormat) - } - // Default storage. + CatalogTableType.MANAGED_TABLE + } + val comment = Option(ctx.STRING).map(string) + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns) + val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns) + val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val selectQuery = Option(ctx.query).map(plan) + + // Note: Hive requires partition columns to be distinct from the schema, so we need + // to include the partition columns here explicitly + val schema = cols ++ partitionCols + + // Storage format + val defaultStorage: CatalogStorageFormat = { val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - // Defined storage. - val fileStorage = format(ctx.createFileFormat) - val rowStorage = format(ctx.rowFormat) - val storage = CatalogStorageFormat( - Option(ctx.locationSpec).map(visitLocationSpec), - fileStorage.inputFormat.orElse(hiveSerDe.inputFormat), - fileStorage.outputFormat.orElse(hiveSerDe.outputFormat), - rowStorage.serde.orElse(hiveSerDe.serde).orElse(fileStorage.serde), - rowStorage.serdeProperties ++ fileStorage.serdeProperties - ) + val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf) + CatalogStorageFormat( + locationUri = None, + inputFormat = defaultHiveSerde.flatMap(_.inputFormat) + .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), + outputFormat = defaultHiveSerde.flatMap(_.outputFormat) + .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + // Note: Keep this unspecified because we use the presence of the serde to decide + // whether to convert a table created by CTAS to a datasource table. 
+ serde = None, + serdeProperties = Map()) + } + val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + .getOrElse(EmptyStorageFormat) + val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat) + val location = Option(ctx.locationSpec).map(visitLocationSpec) + val storage = CatalogStorageFormat( + locationUri = location, + inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), + outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), + serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties) + + // TODO support the sql text - have a proper location for this! + val tableDesc = CatalogTable( + identifier = name, + tableType = tableType, + storage = storage, + schema = schema, + partitionColumnNames = partitionCols.map(_.name), + properties = properties, + comment = comment) - val tableDesc = CatalogTable( - identifier = table, - tableType = tableType, - schema = schema, - partitionColumns = partitionCols, - storage = storage, - properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), - // TODO support the sql text - have a proper location for this! - viewText = Option(ctx.STRING).map(string)) - CTAS(tableDesc, plan(ctx.query), ifNotExists) + selectQuery match { + case Some(q) => CTAS(tableDesc, q, ifNotExists) + case None => CreateTable(tableDesc, ifNotExists) } } + /** + * Create a [[CreateTableLike]] command. + */ + override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { + val targetTable = visitTableIdentifier(ctx.target) + val sourceTable = visitTableIdentifier(ctx.source) + CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null) + } + /** * Create or replace a view. This creates a [[CreateViewAsSelect]] command. 
* @@ -229,9 +258,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { if (ctx.identifierList != null) { throw new ParseException(s"Operation not allowed: partitioned views", ctx) } else { - if (ctx.STRING != null) { - logWarning("COMMENT clause is ignored.") - } val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) val schema = identifiers.map { ic => CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) @@ -239,6 +265,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { createView( ctx, ctx.tableIdentifier, + comment = Option(ctx.STRING).map(string), schema, ctx.query, Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), @@ -255,6 +282,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { createView( ctx, ctx.tableIdentifier, + comment = None, Seq.empty, ctx.query, Map.empty, @@ -268,6 +296,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { private def createView( ctx: ParserRuleContext, name: TableIdentifierContext, + comment: Option[String], schema: Seq[CatalogColumn], query: QueryContext, properties: Map[String, String], @@ -281,7 +310,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { storage = EmptyStorageFormat, properties = properties, viewOriginalText = sql, - viewText = sql) + viewText = sql, + comment = comment) CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) } @@ -296,7 +326,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { recordReader: Token, schemaLess: Boolean): HiveScriptIOSchema = { if (recordWriter != null || recordReader != null) { - logWarning("Used defined record reader/writer classes are currently ignored.") + throw new ParseException( + "Unsupported operation: Used defined record reader/writer classes.", ctx) } // Decode and input/output format. @@ -363,24 +394,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) /** - * Create a [[CatalogStorageFormat]]. The INPUTDRIVER and OUTPUTDRIVER clauses are currently - * ignored. + * Create a [[CatalogStorageFormat]]. */ override def visitTableFileFormat( ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - import ctx._ - if (inDriver != null || outDriver != null) { - logWarning("INPUTDRIVER ... OUTPUTDRIVER ... clauses are ignored.") - } EmptyStorageFormat.copy( - inputFormat = Option(string(inFmt)), - outputFormat = Option(string(outFmt)), - serde = Option(serdeCls).map(string) + inputFormat = Option(string(ctx.inFmt)), + outputFormat = Option(string(ctx.outFmt)), + serde = Option(ctx.serdeCls).map(string) ) } /** - * Resolve a [[HiveSerDe]] based on the format name given. + * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. */ override def visitGenericFileFormat( ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { @@ -397,11 +423,28 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Storage Handlers are currently not supported in the statements we support (CTAS). + * Create a [[RowFormat]] used for creating tables. 
+ * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} */ - override def visitStorageHandler( - ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { - throw new ParseException("Storage Handlers are currently unsupported.", ctx) + private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } } /** @@ -444,13 +487,15 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Create a sequence of [[CatalogColumn]]s from a column list */ - private def visitCatalogColumns( - ctx: ColTypeListContext, - formatter: String => String = identity): Seq[CatalogColumn] = withOrigin(ctx) { + private def visitCatalogColumns(ctx: ColTypeListContext): Seq[CatalogColumn] = withOrigin(ctx) { ctx.colType.asScala.map { col => CatalogColumn( - formatter(col.identifier.getText), - col.dataType.getText.toLowerCase, // TODO validate this? + col.identifier.getText.toLowerCase, + // Note: for types like "STRUCT" we can't + // just convert the whole type string to lower case, otherwise the struct field names + // will no longer be case sensitive. Instead, we rely on our parser to get the proper + // case before passing it to Hive. + CatalystSqlParser.parseDataType(col.dataType.getText).simpleString, nullable = true, Option(col.STRING).map(string)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 64d1341a4755..06badff474f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -46,36 +46,6 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { } } -/** - * Drops a table from the metastore and removes it if it is cached. - */ -private[hive] -case class DropTable( - tableName: String, - ifExists: Boolean) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val ifExistsClause = if (ifExists) "IF EXISTS " else "" - try { - hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) - } catch { - // This table's metadata is not in Hive metastore (e.g. the table does not exist). - case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => - // Other Throwables can be caused by users providing wrong parameters in OPTIONS - // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an error. 
- case e: Throwable => log.warn(s"${e.getMessage}", e) - } - hiveContext.invalidateTable(tableName) - hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.sessionState.catalog.dropTable( - TableIdentifier(tableName), ignoreIfNotExists = true) - Seq.empty[Row] - } -} - private[hive] case class AddJar(path: String) extends RunnableCommand { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 43f445edcb31..21591ec093d3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.{Row, SQLContext} @@ -45,7 +44,6 @@ import org.apache.spark.sql.hive.{HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet private[sql] class DefaultSource extends FileFormat with DataSourceRegister with Serializable { @@ -111,19 +109,6 @@ private[sql] class DefaultSource } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(sqlContext, output, filters, inputFiles).execute() - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index c5f01da4fabb..110c6d19d89b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} class HiveDDLCommandSuite extends PlanTest { @@ -36,11 +37,19 @@ class HiveDDLCommandSuite extends PlanTest { private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { + case CreateTable(desc, allowExisting) => (desc, allowExisting) case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting) case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting) }.head } + private def assertUnsupported(sql: String): Unit = { + val e = intercept[ParseException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("unsupported")) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL 
TABLE IF NOT EXISTS mydb.page_view @@ -69,9 +78,13 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -115,9 +128,13 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -138,10 +155,11 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) assert(desc.storage.serde.isEmpty) assert(desc.properties == Map()) } @@ -173,6 +191,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) @@ -180,6 +199,45 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TEMPORARY TABLE ctas2 + |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + |STORED AS RCFile + |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + |AS SELECT key, value FROM 
src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |CLUSTERED BY(user_id) INTO 256 BUCKETS + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |SKEWED BY (key) ON (1,5,6) + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' + |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' + |FROM testData + """.stripMargin) + } + } + test("Invalid interval term should throw AnalysisException") { def assertError(sql: String, errorMessage: String): Unit = { val e = intercept[AnalysisException] { @@ -227,7 +285,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Script Transform") { - val plan = parser.parsePlan( + parser.parsePlan( """SELECT `t`.`thing1` |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t @@ -235,7 +293,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gentab2`.`gencol2` |FROM `default`.`src` @@ -245,7 +303,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use escaped backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gen``tab2`.`gen``col2` |FROM `default`.`src` @@ -254,6 +312,194 @@ class HiveDDLCommandSuite extends PlanTest { """.stripMargin) } + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.schema == Seq(CatalogColumn("id", "int"), CatalogColumn("name", "string"))) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.storage.serdeProperties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("registerTempTable")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)" + val (desc, _) 
= extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) CLUSTERED BY(id)" + val query1 = s"$baseQuery INTO 10 BUCKETS" + val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.serdeProperties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc3.storage.serdeProperties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + 
assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde.isEmpty) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - location") { + val query = "CREATE TABLE my_table (id int, name string) LOCATION '/path/to/mars'" + val (desc, _) = extractTableDesc(query) + assert(desc.storage.locationUri == Some("/path/to/mars")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + test("create view -- basic") { val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" val (desc, exists) = extractTableDesc(v1) @@ -277,7 +523,7 @@ class HiveDDLCommandSuite extends PlanTest { """ |CREATE OR REPLACE VIEW IF NOT EXISTS view1 |(col1, col3) - |COMMENT 'I cannot spell' + |COMMENT 'BLABLA' |TBLPROPERTIES('prop1Key'="prop1Val") |AS SELECT * FROM tab1 """.stripMargin @@ -297,6 +543,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.outputFormat.isEmpty) assert(desc.storage.serde.isEmpty) assert(desc.properties == Map("prop1Key" -> "prop1Val")) + assert(desc.comment == Option("BLABLA")) } test("create view -- partitioned view") { @@ -305,4 +552,31 @@ class HiveDDLCommandSuite extends PlanTest { 
parser.parsePlan(v1).isInstanceOf[HiveNativeCommand] } } + + test("MSCK repair table (not supported)") { + assertUnsupported("MSCK REPAIR TABLE tab1") + } + + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, exists) = parser.parsePlan(v1).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + } + } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ada8621d0757..8648834f0d88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -88,7 +88,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE) val columns = hiveTable.schema @@ -151,7 +151,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) val columns = hiveTable.schema diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala new file mode 100644 index 000000000000..3677029e7efe --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala @@ -0,0 +1,451 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class HiveShowDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ + + var jsonFilePath: String = _ + override def beforeAll(): Unit = { + super.beforeAll() + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + } + // there are 3 types of create table to test + // 1. HIVE Syntax: create table t1 (c1 int) partitionedby (c2 int) row format... tblproperties.. + // 2. Spark sql syntx: crate table t1 (c1 int) using .. options (... ) + // 3. saving table from datasource: df.write.format("parquet").saveAsTable("t1") + + /** + * Hive syntax DDL + */ + test("Hive syntax DDL: no row format") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |stored as parquet + |location '${tmpDir}' + """.stripMargin) + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - partitioned table with column and table comments") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 bigint, c2 string) + |PARTITIONED BY (c3 int COMMENT 'partition column', c4 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - external, row format and tblproperties") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create external table t1(c1 bigint, c2 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - more row format definition") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create external table t1(c1 int COMMENT 'first column', c2 string) + |COMMENT 'some table' + |PARTITIONED BY (c3 int COMMENT 'partition column', c4 string) + |row format delimited fields terminated by ',' + |COLLECTION ITEMS TERMINATED BY ',' + |MAP KEYS TERMINATED BY ',' + |NULL DEFINED AS 'NaN' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - hive view") { + withTempDir { tmpDir => + withTable("t1") { + withView("v1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + 
""".stripMargin) + sql( + """ + |create view v1 as select * from t1 + """.stripMargin) + assert(compareCatalog( + TableIdentifier("v1"), + sql("show create table v1").collect()(0).toSeq(0).toString)) + } + } + } + } + + test("Hive syntax DDL - SERDE") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |WITH SERDEPROPERTIES ('mapkey.delim'=',', 'field.delim'=',') + |STORED AS + |INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + |OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |location '${tmpDir}' + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + // hive native DDL that is not supported by Spark SQL DDL + test("Hive syntax DDL -- CLUSTERED") { + withTable("t1") { + HiveNativeCommand( + s""" + |create table t1 (c1 int, c2 string) + |clustered by (c1) sorted by (c2 desc) into 5 buckets + """.stripMargin).run(sqlContext) + val ddl = sql("show create table t1").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString + .contains("CLUSTERED BY ( `c1` ) \nSORTED BY ( `c2` DESC ) \nINTO 5 BUCKETS")) + } + } + + test("Hive syntax DDL -- STORED BY") { + withTable("tmp_showcrt1") { + HiveNativeCommand( + s""" + |CREATE EXTERNAL TABLE tmp_showcrt1 (key string, value boolean) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' + |STORED BY 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' + |WITH SERDEPROPERTIES ('field.delim'=',', 'serialization.format'=',') + """.stripMargin).run(sqlContext) + val ddl = sql("show create table tmp_showcrt1").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString + .contains("STORED BY 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler'")) + } + } + + test("Hive syntax DDL -- SKEWED BY") { + withTable("stored_as_dirs_multiple") { + HiveNativeCommand( + s""" + |CREATE TABLE stored_as_dirs_multiple (col1 STRING, col2 int, col3 STRING) + |SKEWED BY (col1, col2) ON (('s1',1), ('s3',3), ('s13',13), ('s78',78)) + |stored as DIRECTORIES + """.stripMargin).run(sqlContext) + val ddl = sql("show create table stored_as_dirs_multiple").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString.contains("SKEWED BY ( `col1`, `col2` )")) + } + } + + /** + * Datasource table syntax DDL + */ + test("Datasource Table DDL syntax - persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath') + """.stripMargin) + assert(compareCatalog( + TableIdentifier("jsonTable"), + sql("show create table jsonTable").collect()(0).toSeq(0).toString)) + } + } + + test("Datasource Table DDL syntax - persistent JSON with a subset of user-specified fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. 
+ sql( + s""" + |CREATE TABLE jsonTable (`` Struct<`=`:array>>, b String) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS (path '$jsonFilePath') + """.stripMargin) + assert(compareCatalog( + TableIdentifier("jsonTable"), + sql("show create table jsonTable").collect()(0).toSeq(0).toString)) + } + } + + test("Datasource Table DDL syntax - USING and OPTIONS - no user-specified schema") { + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath', + | key.key1 'value1', + | 'key.key2' 'value2' + |) + """.stripMargin) + assert(compareCatalog( + TableIdentifier("jsonTable"), + sql("show create table jsonTable").collect()(0).toSeq(0).toString)) + } + } + + test("Datasource Table DDL syntax - USING and NO options") { + withTable("parquetTable") { + sql( + s""" + |CREATE TABLE parquetTable (c1 int, c2 string, c3 long) + |USING parquet + """.stripMargin) + assert(compareCatalog( + TableIdentifier("parquetTable"), + sql("show create table parquetTable").collect()(0).toSeq(0).toString)) + } + } + + /** + * Datasource saved to table + */ + test("Save datasource to table - dataframe with a select ") { + withTable("t_datasource") { + val df = sql("select 1, 'abc'") + df.write.saveAsTable("t_datasource") + assert(compareCatalog( + TableIdentifier("t_datasource"), + sql("show create table t_datasource").collect()(0).toSeq(0).toString)) + } + } + + test("Save datasource to table - dataframe from json file") { + val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + withTable("t_datasource") { + val df = sqlContext.read.json(jsonFilePath) + df.write.format("json").saveAsTable("t_datasource") + assert(compareCatalog( + TableIdentifier("t_datasource"), + sql("show create table t_datasource").collect()(0).toSeq(0).toString)) + } + } + + test("Save datasource to table -- dataframe with user-specified schema and partitioned") { + // TODO partitioning information will be lost in the DDL, because datasource DDL syntax + // does not have a place to keep the partitioning columns. 
+ withTable("ttt3") { + val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") + df.write + .partitionBy("a") + .format("parquet") + .mode(SaveMode.Overwrite) + .saveAsTable("ttt3") + assert(compareCatalog( + TableIdentifier("ttt3"), + sql("show create table ttt3").collect()(0).toSeq(0).toString)) + } + } + + /** + * In order to verify whether the generated DDL from a table is correct, we can + * compare the CatalogTable generated from the existing table to the CatalogTable + * generated from the table created with the the generated DDL + * @param expectedTable + * @param actualDDL + * @return true or false + */ + private def compareCatalog(expectedTable: TableIdentifier, actualDDL: String): Boolean = { + val actualTable = expectedTable.table + "_actual" + var actual: CatalogTable = null + val expected: CatalogTable = + sqlContext.sessionState.catalog.getTableMetadata(expectedTable) + withTempDir { tmpDir => + if (actualDDL.contains("CREATE VIEW")) { + withView(actualTable) { + sql(actualDDL.replace(expectedTable.table.toLowerCase(), actualTable)) + actual = sqlContext.sessionState.catalog.getTableMetadata(TableIdentifier(actualTable)) + } + } else { + withTable(actualTable) { + var revisedActualDDL: String = null + if (expected.tableType == CatalogTableType.EXTERNAL_TABLE) { + revisedActualDDL = actualDDL.replace(expectedTable.table.toLowerCase(), actualTable) + } else { + revisedActualDDL = actualDDL + .replace(expectedTable.table.toLowerCase(), actualTable) + .replaceAll("path.*,", s"path '${tmpDir}',") + } + sql(revisedActualDDL) + actual = sqlContext.sessionState.catalog.getTableMetadata(TableIdentifier(actualTable)) + } + } + } + + if (expected.properties.get("spark.sql.sources.provider").isDefined) { + // datasource table: The generated DDL will be like: + // CREATE EXTERNAL TABLE () USING + // OPTIONS (
@@ -359,13 +359,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -374,8 +374,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -385,8 +385,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -397,8 +397,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -412,8 +412,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
       // The scheduler delay includes the network delay to send the task to the worker
       // machine and to send back the result (but not the time to fetch the task result,
       // if it needed to be fetched from the block manager on the worker).
-      val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
-        getSchedulerDelay(info, metrics.get, currentTime).toDouble
+      val schedulerDelays = validTasks.map { taskUIData: TaskUIData =>
+        getSchedulerDelay(taskUIData.taskInfo, taskUIData.metrics.get, currentTime).toDouble
       }
@@ -461,11 +461,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -476,8 +476,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
@@ -488,25 +488,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
[These StagePage hunks change only the rendered task-table markup: the "Duration",
"Scheduler Delay", "Input Size / Records", "Output Size / Records",
"Shuffle Write Size / Records", "Shuffle spill (memory)" and "Shuffle spill (disk)"
column headers. The surrounding HTML is not recoverable here.]

[docs/configuration.md]

  spark.driver.extraJavaOptions    (none)
    A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+   Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap
+   size settings can be set with spark.driver.memory in the cluster mode and through
+   the --driver-memory command line option in the client mode.
    Note: In client mode, this config must not be set through the SparkConf
    directly in your application, because the driver JVM has already started at that point.
    Instead, please set this through the --driver-java-options command line option or in
    your default properties file.
  spark.executor.extraJavaOptions    (none)
    A string of extra JVM options to pass to executors. For instance, GC settings or other logging.
-   Note that it is illegal to set Spark properties or heap size settings with this option. Spark
-   properties should be set using a SparkConf object or the spark-defaults.conf file used with the
-   spark-submit script. Heap size settings can be set with spark.executor.memory.
+   Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this
+   option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file
+   used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory.
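To make the rule above concrete, here is a minimal Scala sketch; the application name and JVM
flags are illustrative placeholders, not values taken from this patch. Maximum heap goes through
spark.executor.memory, Spark properties go on the SparkConf itself, and
spark.executor.extraJavaOptions carries only plain JVM flags such as GC or logging settings.

    import org.apache.spark.SparkConf

    // Illustrative values only (e.g. for use in spark-shell or inside an application).
    val conf = new SparkConf()
      .setAppName("executor-java-options-example")
      // Maximum heap is set via spark.executor.memory, never via -Xmx in extraJavaOptions.
      .set("spark.executor.memory", "4g")
      // Only non-heap JVM flags (GC tuning, logging, ...) belong in extraJavaOptions.
      .set("spark.executor.extraJavaOptions", "-XX:+UseG1GC -verbose:gc")
    // Spark properties themselves stay on the SparkConf (or in spark-defaults.conf / --conf),
    // never embedded as -Dspark.* inside extraJavaOptions; pass this conf to SparkContext as usual.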
[docs/running-on-yarn.md]

  spark.yarn.am.extraJavaOptions    (none)
    A string of extra JVM options to pass to the YARN Application Master in client mode.
-   In cluster mode, use spark.driver.extraJavaOptions instead.
+   In cluster mode, use spark.driver.extraJavaOptions instead. Note that it is illegal
+   to set maximum heap size (-Xmx) settings with this option. Maximum heap size settings can be set
+   with spark.yarn.am.memory.
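Driver and YARN Application Master options, by contrast, must be known before those JVMs start,
so per the entries above they are supplied at launch time rather than from application code. A
hypothetical example follows (application name, JAR and flag values are placeholders), written as
Scala comments because these are spark-submit and properties-file settings rather than API calls.

    // ./bin/spark-submit \
    //   --driver-memory 4g \
    //   --driver-java-options "-XX:+UseG1GC -verbose:gc" \
    //   --conf "spark.yarn.am.extraJavaOptions=-XX:+PrintGCDetails" \
    //   --class com.example.MyApp my-app.jar
    //
    // or, equivalently, in conf/spark-defaults.conf:
    //
    //   spark.driver.memory             4g
    //   spark.driver.extraJavaOptions   -XX:+UseG1GC -verbose:gc
    //   spark.yarn.am.extraJavaOptions  -XX:+PrintGCDetails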