diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 5c074d3c0fd40..4e3fe00a2e9bd 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -63,6 +63,7 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", + "spark.decisionTree", "spark.randomForest", "spark.gbt", "spark.bisectingKmeans", @@ -414,6 +415,8 @@ export("as.DataFrame", "print.summary.GeneralizedLinearRegressionModel", "read.ml", "print.summary.KSTest", + "print.summary.DecisionTreeRegressionModel", + "print.summary.DecisionTreeClassificationModel", "print.summary.RandomForestRegressionModel", "print.summary.RandomForestClassificationModel", "print.summary.GBTRegressionModel", @@ -452,6 +455,8 @@ S3method(print, structField) S3method(print, structType) S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) +S3method(print, summary.DecisionTreeRegressionModel) +S3method(print, summary.DecisionTreeClassificationModel) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) S3method(print, summary.GBTRegressionModel) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 514ca99d45cd3..5630d0c8a0df9 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1506,6 +1506,11 @@ setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.ml #' @export setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("spark.naiveBayes") }) +#' @rdname spark.decisionTree +#' @export +setGeneric("spark.decisionTree", + function(data, formula, ...) 
{ standardGeneric("spark.decisionTree") }) + #' @rdname spark.randomForest #' @export setGeneric("spark.randomForest", diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 82279be6fbe77..2f1220a752783 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -45,6 +45,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a DecisionTreeRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeRegressionModel +#' @export +#' @note DecisionTreeRegressionModel since 2.3.0 +setClass("DecisionTreeRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a DecisionTreeClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala DecisionTreeClassificationModel +#' @export +#' @note DecisionTreeClassificationModel since 2.3.0 +setClass("DecisionTreeClassificationModel", representation(jobj = "jobj")) + # Create the summary of a tree ensemble model (eg. 
Random Forest, GBT) summary.treeEnsemble <- function(model) { jobj <- model@jobj @@ -81,6 +95,36 @@ print.summary.treeEnsemble <- function(x) { invisible(x) } +# Create the summary of a decision tree model +summary.decisionTree <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + maxDepth = maxDepth, + jobj = jobj) +} + +# Prints the summary of decision tree models +print.summary.decisionTree <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + #' Gradient Boosted Tree Model for Regression and Classification #' #' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a @@ -499,3 +543,199 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path function(object, path, overwrite = FALSE) { write_internal(object, path, overwrite) }) + +#' Decision Tree Model for Regression and Classification +#' +#' \code{spark.decisionTree} fits a Decision Tree Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Decision Tree +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. 
+#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ +#' Decision Tree Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ +#' Decision Tree Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param seed integer seed for random number generation. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Checkpoint interval (>= 1), or -1 to disable checkpointing. +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often the +#' cache should be checkpointed, or disable checkpointing, by setting checkpointInterval. +#' @param ... additional arguments passed to the method. 
+#' @aliases spark.decisionTree,SparkDataFrame,formula-method +#' @return \code{spark.decisionTree} returns a fitted Decision Tree model. +#' @rdname spark.decisionTree +#' @name spark.decisionTree +#' @export +#' @examples +#' \dontrun{ +#' # fit a Decision Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.decisionTree(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Decision Tree Classification Model +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.decisionTree(df, Survived ~ Freq + Age, "classification") +#' } +#' @note spark.decisionTree since 2.3.0 +setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("DecisionTreeRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, 
c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), impurity, + as.integer(minInstancesPerNode), as.numeric(minInfoGain), + as.integer(checkpointInterval), seed, + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("DecisionTreeClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Decision Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeRegressionModel-method +#' @export +#' @note summary(DecisionTreeRegressionModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeRegressionModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeRegressionModel" + ans + }) + +# Prints the summary of Decision Tree Regression Model + +#' @param x summary object of Decision Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeRegressionModel since 2.3.0 +print.summary.DecisionTreeRegressionModel <- function(x, ...) 
{ + print.summary.decisionTree(x) +} + +# Get the summary of a Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @aliases summary,DecisionTreeClassificationModel-method +#' @export +#' @note summary(DecisionTreeClassificationModel) since 2.3.0 +setMethod("summary", signature(object = "DecisionTreeClassificationModel"), + function(object) { + ans <- summary.decisionTree(object) + class(ans) <- "summary.DecisionTreeClassificationModel" + ans + }) + +# Prints the summary of Decision Tree Classification Model + +#' @rdname spark.decisionTree +#' @export +#' @note print.summary.DecisionTreeClassificationModel since 2.3.0 +print.summary.DecisionTreeClassificationModel <- function(x, ...) { + print.summary.decisionTree(x) +} + +# Makes predictions from a Decision Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing the predicted labels in a column named +#' "prediction". +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeRegressionModel-method +#' @export +#' @note predict(DecisionTreeRegressionModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.decisionTree +#' @aliases predict,DecisionTreeClassificationModel-method +#' @export +#' @note predict(DecisionTreeClassificationModel) since 2.3.0 +setMethod("predict", signature(object = "DecisionTreeClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Decision Tree Regression or Classification model to the input path. + +#' @param object A fitted Decision Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Whether to overwrite if the output path already exists. Default is FALSE +#' which means an exception is thrown if the output path exists. 
+#' +#' @aliases write.ml,DecisionTreeRegressionModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeRegressionModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,DecisionTreeClassificationModel,character-method +#' @rdname spark.decisionTree +#' @export +#' @note write.ml(DecisionTreeClassificationModel, character) since 2.3.0 +setMethod("write.ml", signature(object = "DecisionTreeClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index 5dfef8625061b..a53c92c2c4815 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -32,8 +32,9 @@ #' @rdname write.ml #' @name write.ml #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, #' @seealso \link{spark.mlp}, \link{spark.naiveBayes}, @@ -48,8 +49,9 @@ NULL #' @rdname predict #' @name predict #' @export -#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture}, -#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg}, +#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.decisionTree}, +#' @seealso \link{spark.gaussianMixture}, \link{spark.gbt}, +#' @seealso \link{spark.glm}, \link{glm}, \link{spark.isoreg}, #' @seealso \link{spark.kmeans}, #' 
@seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear} @@ -110,6 +112,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeRegressorWrapper")) { + new("DecisionTreeRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.DecisionTreeClassifierWrapper")) { + new("DecisionTreeClassificationModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { new("GBTRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index 146bc2878e263..b283e734cec53 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -209,4 +209,90 @@ test_that("spark.randomForest", { expect_equal(summary(model)$numFeatures, 4) }) +test_that("spark.decisionTree", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.decisionTree(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + modelPath <- tempfile(pattern = "spark-decisionTreeRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, 
overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + + # The saved model is a directory, so it has to be removed recursively + unlink(modelPath, recursive = TRUE) + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.decisionTree(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-decisionTreeClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + # Compare the fields that summary() actually returns for a decision tree model + # (names missing from the summary list extract as NULL on both sides and can never fail) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + + # The saved model is a directory, so it has to be removed recursively + unlink(modelPath, recursive = TRUE) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.decisionTree(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$maxDepth, 5) + + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + 
expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.decisionTree classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.decisionTree(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + sparkR.session.stop() diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a8488d8d1b704..7e24315c5f39b 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -87,6 +87,9 @@ *:* + + org.scala-lang:scala-library + @@ -98,7 +101,7 @@ - + com.fasterxml.jackson ${spark.shade.packageName}.com.fasterxml.jackson diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 6643a8f361cdc..d430d8c5fb35a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -26,7 +26,6 @@ function getThreadDumpEnabled() { } function formatStatus(status, type) { - if (type !== 'display') return status; if (status) { return "Active" } else { @@ -417,7 +416,6 @@ $(document).ready(function () { }, {data: 'hostPort'}, {data: 'isActive', render: function (data, type, row) { - if (type !== 'display') return data; if (row.isBlacklisted) return "Blacklisted"; else return formatStatus (data, type); } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7f7921d56f49e..e193ed222e228 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -278,4 +278,13 @@ package object config { "spark.io.compression.codec.") .booleanConf .createWithDefault(false) + + private[spark] 
val SHUFFLE_ACCURATE_BLOCK_THRESHOLD = + ConfigBuilder("spark.shuffle.accurateBlockThreshold") + .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " + + "record the size accurately if it's above this config. This helps to prevent OOM by " + + "avoiding underestimating shuffle block size when fetch shuffle blocks.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(100 * 1024 * 1024) + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index b2e9a97129f08..048e0d0186594 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,8 +19,13 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.roaringbitmap.RoaringBitmap +import org.apache.spark.SparkEnv +import org.apache.spark.internal.config import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus( } /** - * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, + * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger + * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks, * plus a bitmap for tracking which blocks are empty. * * @param loc location where the task is being executed * @param numNonEmptyBlocks the number of non-empty blocks * @param emptyBlocks a bitmap tracking which blocks are empty - * @param avgSize average size of the non-empty blocks + * @param avgSize average size of the non-empty and non-huge blocks + * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
*/ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, - private[this] var avgSize: Long) + private[this] var avgSize: Long, + @transient private var hugeBlockSizes: Map[Int, Byte]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization - require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, + require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1) // For deserialization only + protected def this() = this(null, -1, null, -1, null) // For deserialization only override def location: BlockManagerId = loc override def getSizeForBlock(reduceId: Int): Long = { + assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { 0 } else { - avgSize + hugeBlockSizes.get(reduceId) match { + case Some(size) => MapStatus.decompressSize(size) + case None => avgSize + } } } @@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private ( loc.writeExternal(out) emptyBlocks.writeExternal(out) out.writeLong(avgSize) + out.writeInt(hugeBlockSizes.size) + hugeBlockSizes.foreach { kv => + out.writeInt(kv._1) + out.writeByte(kv._2) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private ( emptyBlocks = new RoaringBitmap() emptyBlocks.readExternal(in) avgSize = in.readLong() + val count = in.readInt() + val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]() + (0 until count).foreach { _ => + val block = in.readInt() + val size = in.readByte() + hugeBlockSizesArray += Tuple2(block, size) + } + hugeBlockSizes = hugeBlockSizesArray.toMap } } @@ -178,11 +203,21 @@ private[spark] object 
HighlyCompressedMapStatus { // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. val emptyBlocks = new RoaringBitmap() val totalNumBlocks = uncompressedSizes.length + val threshold = Option(SparkEnv.get) + .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD)) + .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get) + val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]() while (i < totalNumBlocks) { - var size = uncompressedSizes(i) + val size = uncompressedSizes(i) if (size > 0) { numNonEmptyBlocks += 1 - totalSize += size + // Huge blocks are not included in the calculation for average size, thus size for smaller + // blocks is more accurate. + if (size < threshold) { + totalSize += size + } else { + hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i))) + } } else { emptyBlocks.add(i) } @@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus { } emptyBlocks.trim() emptyBlocks.runOptimize() - new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, + hugeBlockSizesArray.toMap) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index a0fd29c22ddca..cce7a7611b420 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -631,7 +631,8 @@ private[ui] class JobPagedTable( {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} - {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + {UIUtils.makeProgressBar(started = job.numActiveTasks, + completed = job.completedIndices.size, failed = job.numFailedTasks, skipped = job.numSkippedTasks, reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 7370f9feb68cd..1b10feb36e439 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -423,6 +423,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { jobData.numActiveTasks -= 1 taskEnd.reason match { case Success => + jobData.completedIndices.add((taskEnd.stageId, info.index)) jobData.numCompletedTasks += 1 case kill: TaskKilled => jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 8d280bc00c3b3..048c4ad0146e2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -62,6 +62,7 @@ private[spark] object UIData { var numTasks: Int = 0, var numActiveTasks: Int = 0, var numCompletedTasks: Int = 0, + var completedIndices: OpenHashSet[(Int, Int)] = new OpenHashSet[(Int, Int)](), var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, var reasonToNumKilled: Map[String, Int] = Map.empty, diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 7a897c2b4698f..c0126e41ff7fa 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -38,6 +38,10 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def beforeAll() { super.beforeAll() + // Once 'spark.local.dir' is set, it is cached. Unless this is manually cleared + // before/after a test, it could return the same directory even if this property + // is configured. 
+ Utils.clearLocalRootDirs() conf.set("spark.shuffle.manager", "sort") } @@ -50,6 +54,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { override def afterEach(): Unit = { try { Utils.deleteRecursively(tempDir) + Utils.clearLocalRootDirs() } finally { super.afterEach() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 759d52fca5ce1..3ec37f674c77b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} + import scala.util.Random +import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.BlockManagerId @@ -128,4 +132,26 @@ class MapStatusSuite extends SparkFunSuite { assert(size1 === size2) assert(!success) } + + test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " + + "underestimated.") { + val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000") + val env = mock(classOf[SparkEnv]) + doReturn(conf).when(env).conf + SparkEnv.set(env) + // Value of element in sizes is equal to the corresponding index. 
+ val sizes = (0L to 2000L).toArray + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val arrayStream = new ByteArrayOutputStream(102400) + val objectOutputStream = new ObjectOutputStream(arrayStream) + assert(status1.isInstanceOf[HighlyCompressedMapStatus]) + objectOutputStream.writeObject(status1) + objectOutputStream.flush() + val array = arrayStream.toByteArray + val objectInput = new ObjectInputStream(new ByteArrayInputStream(array)) + val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus] + (1001 to 2000).foreach { + case part => assert(status2.getSizeForBlock(part) >= sizes(part)) + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index bdd148875e38a..267c8dc1bd750 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -320,12 +320,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(5 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") find(cssSelector(".stage-progress-cell")).get.text should be ("2/2 (1 failed)") - // Ideally, the following test would pass, but currently we overcount completed tasks - // if task recomputations occur: - // find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") - // Instead, we guarantee that the total number of tasks is always correct, while the number - // of completed tasks may be higher: - find(cssSelector(".progress-cell .progress")).get.text should be ("3/2 (1 failed)") + find(cssSelector(".progress-cell .progress")).get.text should be ("2/2 (1 failed)") } val jobJson = getJson(sc.ui.get, "jobs") (jobJson \ "numTasks").extract[Int]should be (2) diff --git a/docs/configuration.md b/docs/configuration.md index 1d8d963016c71..a6b6d5dfa5f95 100644 --- a/docs/configuration.md +++ b/docs/configuration.md 
@@ -612,6 +612,15 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec. + + spark.shuffle.accurateBlockThreshold + 100 * 1024 * 1024 + + When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will record the + size accurately if it's above this threshold. This helps to prevent OOM by avoiding + underestimating shuffle block size when fetching shuffle blocks. + + spark.io.encryption.enabled false diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 362e883e55e83..fb4621389ab92 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -71,21 +71,24 @@ To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.4 The list below highlights some of the new features and enhancements added to MLlib in the `2.2` release of Spark: -* `ALS` methods for _top-k_ recommendations for all users or items, matching the functionality - in `mllib` ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). Performance - was also improved for both `ml` and `mllib` +* [`ALS`](ml-collaborative-filtering.html) methods for _top-k_ recommendations for all + users or items, matching the functionality in `mllib` + ([SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)). 
+ Performance was also improved for both `ml` and `mllib` ([SPARK-11968](https://issues.apache.org/jira/browse/SPARK-11968) and [SPARK-20587](https://issues.apache.org/jira/browse/SPARK-20587)) -* `Correlation` and `ChiSquareTest` stats functions for `DataFrames` +* [`Correlation`](ml-statistics.html#correlation) and + [`ChiSquareTest`](ml-statistics.html#hypothesis-testing) stats functions for `DataFrames` ([SPARK-19636](https://issues.apache.org/jira/browse/SPARK-19636) and [SPARK-19635](https://issues.apache.org/jira/browse/SPARK-19635)) -* `FPGrowth` algorithm for frequent pattern mining +* [`FPGrowth`](ml-frequent-pattern-mining.html#fp-growth) algorithm for frequent pattern mining ([SPARK-14503](https://issues.apache.org/jira/browse/SPARK-14503)) * `GLM` now supports the full `Tweedie` family ([SPARK-18929](https://issues.apache.org/jira/browse/SPARK-18929)) -* `Imputer` feature transformer to impute missing values in a dataset +* [`Imputer`](ml-features.html#imputer) feature transformer to impute missing values in a dataset ([SPARK-13568](https://issues.apache.org/jira/browse/SPARK-13568)) -* `LinearSVC` for linear Support Vector Machine classification +* [`LinearSVC`](ml-classification-regression.html#linear-support-vector-machine) + for linear Support Vector Machine classification ([SPARK-14709](https://issues.apache.org/jira/browse/SPARK-14709)) * Logistic regression now supports constraints on the coefficients during training ([SPARK-20047](https://issues.apache.org/jira/browse/SPARK-20047)) diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 7cf5b7379503f..137ef74843da5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -68,10 +68,12 @@ public List buildCommand(Map env) case 
"org.apache.spark.executor.CoarseGrainedExecutorBackend": javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + extraClassPath = getenv("SPARK_EXECUTOR_CLASSPATH"); break; case "org.apache.spark.executor.MesosExecutorBackend": javaOptsKeys.add("SPARK_EXECUTOR_OPTS"); memKey = "SPARK_EXECUTOR_MEMORY"; + extraClassPath = getenv("SPARK_EXECUTOR_CLASSPATH"); break; case "org.apache.spark.deploy.mesos.MesosClusterDispatcher": javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS"); diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala new file mode 100644 index 0000000000000..7f59825504d8e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + import DecisionTreeClassifierWrapper._ + + private val dtcModel: DecisionTreeClassificationModel = + pipeline.stages(1).asInstanceOf[DecisionTreeClassificationModel] + + lazy val numFeatures: Int = dtcModel.numFeatures + lazy val featureImportances: Vector = dtcModel.featureImportances + lazy val maxDepth: Int = dtcModel.getMaxDepth + + def summary: String = dtcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(dtcModel.getFeaturesCol) + .drop(dtcModel.getLabelCol) + } + + override def write: MLWriter = new + DecisionTreeClassifierWrapper.DecisionTreeClassifierWrapperWriter(this) +} + +private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + seed: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + .setForceIndexLabel(true) + 
checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) + + // assemble and fit the pipeline + val dtc = new DecisionTreeClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + if (seed != null && seed.length > 0) dtc.setSeed(seed.toLong) + + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, dtc, idxToStr)) + .fit(data) + + new DecisionTreeClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[DecisionTreeClassifierWrapper] = + new DecisionTreeClassifierWrapperReader + + override def load(path: String): DecisionTreeClassifierWrapper = super.load(path) + + class DecisionTreeClassifierWrapperWriter(instance: DecisionTreeClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class DecisionTreeClassifierWrapperReader extends MLReader[DecisionTreeClassifierWrapper] { + + override def load(path: String): DecisionTreeClassifierWrapper = { + 
implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new DecisionTreeClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala new file mode 100644 index 0000000000000..de712d67e6df5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeRegressionWrapper.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class DecisionTreeRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val dtrModel: DecisionTreeRegressionModel = + pipeline.stages(1).asInstanceOf[DecisionTreeRegressionModel] + + lazy val numFeatures: Int = dtrModel.numFeatures + lazy val featureImportances: Vector = dtrModel.featureImportances + lazy val maxDepth: Int = dtrModel.getMaxDepth + + def summary: String = dtrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(dtrModel.getFeaturesCol) + } + + override def write: MLWriter = new + DecisionTreeRegressorWrapper.DecisionTreeRegressorWrapperWriter(this) +} + +private[r] object DecisionTreeRegressorWrapper extends MLReadable[DecisionTreeRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + impurity: String, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + seed: String, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): DecisionTreeRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = 
AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val dtr = new DecisionTreeRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setImpurity(impurity) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) dtr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, dtr)) + .fit(data) + + new DecisionTreeRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[DecisionTreeRegressorWrapper] = new DecisionTreeRegressorWrapperReader + + override def load(path: String): DecisionTreeRegressorWrapper = super.load(path) + + class DecisionTreeRegressorWrapperWriter(instance: DecisionTreeRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class DecisionTreeRegressorWrapperReader extends MLReader[DecisionTreeRegressorWrapper] { + + override def load(path: String): DecisionTreeRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val 
formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new DecisionTreeRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index b30ce12bc6cc8..ba6445a730306 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] { RandomForestRegressorWrapper.load(path) case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => RandomForestClassifierWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeRegressorWrapper" => + DecisionTreeRegressorWrapper.load(path) + case "org.apache.spark.ml.r.DecisionTreeClassifierWrapper" => + DecisionTreeClassifierWrapper.load(path) case "org.apache.spark.ml.r.GBTRegressorWrapper" => GBTRegressorWrapper.load(path) case "org.apache.spark.ml.r.GBTClassifierWrapper" => diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 6c39fe5d84865..2b2b5fe49ea32 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -992,7 +992,16 @@ object Matrices { new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => // There is no isTranspose flag for sparse matrices in Breeze - new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) + val nsm = if (sm.rowIndices.length > sm.activeSize) { + // This sparse matrix has trailing zeros. + // Remove them by compacting the matrix. 
+ val csm = sm.copy + csm.compact() + csm + } else { + sm + } + new SparseMatrix(nsm.rows, nsm.cols, nsm.colPtrs, nsm.rowIndices, nsm.data) case _ => throw new UnsupportedOperationException( s"Do not support conversion from type ${breeze.getClass.getName}.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 563756907d201..93c00d80974c3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -513,6 +513,26 @@ class MatricesSuite extends SparkFunSuite { Matrices.fromBreeze(sum) } + test("Test FromBreeze when Breeze.CSCMatrix.rowIndices has trailing zeros. - SPARK-20687") { + // (2, 0, 0) + // (2, 0, 0) + val mat1Brz = Matrices.sparse(2, 3, Array(0, 2, 2, 2), Array(0, 1), Array(2, 2)).asBreeze + // (2, 1E-15, 1E-15) + // (2, 1E-15, 1E-15) + val mat2Brz = Matrices.sparse(2, 3, + Array(0, 2, 4, 6), + Array(0, 0, 0, 1, 1, 1), + Array(2, 1E-15, 1E-15, 2, 1E-15, 1E-15)).asBreeze + val t1Brz = mat1Brz - mat2Brz + val t2Brz = mat2Brz - mat1Brz + // The following operations raise exceptions on un-patch Matrices.fromBreeze + val t1 = Matrices.fromBreeze(t1Brz) + val t2 = Matrices.fromBreeze(t2Brz) + // t1 == t1Brz && t2 == t2Brz + assert((t1.asBreeze - t1Brz).iterator.map((x) => math.abs(x._2)).sum < 1E-15) + assert((t2.asBreeze - t2Brz).iterator.map((x) => math.abs(x._2)).sum < 1E-15) + } + test("row/col iterator") { val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0)) val sm = dm.toSparse diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 8d25f5b3a771a..955bc9768ce77 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2082,10 +2082,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, """ A label indexer that maps a string column of labels to an ML column of label 
indices. If the input column is numeric, we cast it to string and index the string values. - The indices are in [0, numLabels), ordered by label frequencies. - So the most frequent label gets index 0. + The indices are in [0, numLabels). By default, this is ordered by label frequencies + so the most frequent label gets index 0. The ordering behavior is controlled by + setting :py:attr:`stringOrderType`. Its default value is 'frequencyDesc'. - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error') + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", + ... stringOrderType="frequencyDesc") >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), @@ -2111,26 +2113,45 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, >>> loadedInverter = IndexToString.load(indexToStringPath) >>> loadedInverter.getLabels() == inverter.getLabels() True + >>> stringIndexer.getStringOrderType() + 'frequencyDesc' + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="error", + ... stringOrderType="alphabetDesc") + >>> model = stringIndexer.fit(stringIndDf) + >>> td = model.transform(stringIndDf) + >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), + ... key=lambda x: x[0]) + [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)] .. versionadded:: 1.4.0 """ + stringOrderType = Param(Params._dummy(), "stringOrderType", + "How to order labels of string column. The first label after " + + "ordering is assigned an index of 0. 
Supported options: " + + "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): + def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", + stringOrderType="frequencyDesc"): """ - __init__(self, inputCol=None, outputCol=None, handleInvalid="error") + __init__(self, inputCol=None, outputCol=None, handleInvalid="error", \ + stringOrderType="frequencyDesc") """ super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) - self._setDefault(handleInvalid="error") + self._setDefault(handleInvalid="error", stringOrderType="frequencyDesc") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): + def setParams(self, inputCol=None, outputCol=None, handleInvalid="error", + stringOrderType="frequencyDesc"): """ - setParams(self, inputCol=None, outputCol=None, handleInvalid="error") + setParams(self, inputCol=None, outputCol=None, handleInvalid="error", \ + stringOrderType="frequencyDesc") Sets params for this StringIndexer. """ kwargs = self._input_kwargs @@ -2139,6 +2160,20 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): def _create_model(self, java_model): return StringIndexerModel(java_model) + @since("2.3.0") + def setStringOrderType(self, value): + """ + Sets the value of :py:attr:`stringOrderType`. + """ + return self._set(stringOrderType=value) + + @since("2.3.0") + def getStringOrderType(self): + """ + Gets the value of :py:attr:`stringOrderType` or its default value 'frequencyDesc'. 
+ """ + return self.getOrDefault(self.stringOrderType) + class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 3c3fcc8d9b8d8..2d17f95b0c44f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -323,6 +323,14 @@ def numInstances(self): """ return self._call_java("numInstances") + @property + @since("2.2.0") + def degreesOfFreedom(self): + """ + Degrees of freedom. + """ + return self._call_java("degreesOfFreedom") + @property @since("2.0.0") def devianceResiduals(self): @@ -1565,6 +1573,14 @@ def predictionCol(self): """ return self._call_java("predictionCol") + @property + @since("2.2.0") + def numInstances(self): + """ + Number of instances in DataFrame predictions. + """ + return self._call_java("numInstances") + @property @since("2.0.0") def rank(self): diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 8f5b97ccb1f85..ac7aec7b0a034 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -185,6 +185,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( def createCommand(offer: Offer, numCores: Int, taskId: String): CommandInfo = { val environment = Environment.newBuilder() + val extraClassPath = conf.getOption("spark.executor.extraClassPath") + extraClassPath.foreach { cp => + environment.addVariables( + Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) + } val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") // Set the environment variable through a 
command prefix diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 735c879c63c55..66b8e0a640121 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -106,6 +106,10 @@ private[spark] class MesosFineGrainedSchedulerBackend( throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") } val environment = Environment.newBuilder() + sc.conf.getOption("spark.executor.extraClassPath").foreach { cp => + environment.addVariables( + Environment.Variable.newBuilder().setName("SPARK_EXECUTOR_CLASSPATH").setValue(cp).build()) + } val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 75bf780d41424..ed423e7e334b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -366,7 +366,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a7bf81e98be8e..bf46a39862131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,9 +232,10 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } @@ -347,9 +348,10 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType)) + Seq(TypeCollection(LongType, DoubleType, DecimalType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { + case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7a54995453797..d291ca0020838 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -410,17 +410,20 @@ case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) exten * would have Map('a' -> Some('1'), 'b' -> None). 
* @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param ifNotExists If true, only write if the table or partition does not exist. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. */ case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifPartitionNotExists: Boolean) extends LogicalPlan { - assert(overwrite || !ifNotExists) - assert(partition.values.forall(_.nonEmpty) || !ifNotExists) + // IF NOT EXISTS is only valid in INSERT OVERWRITE + assert(overwrite || !ifPartitionNotExists) + // IF NOT EXISTS is only valid in static partitions + assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists) // We don't want `table` in children as sometimes we don't want to transform it. override def children: Seq[LogicalPlan] = query :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b97adf7221d18..c5d69c204642e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -303,7 +303,7 @@ object SQLConf { val HIVE_MANAGE_FILESOURCE_PARTITIONS = buildConf("spark.sql.hive.manageFilesourcePartitions") .doc("When true, enable metastore partition management for file source tables as well. " + - "This includes both datasource and converted Hive tables. When partition managment " + + "This includes both datasource and converted Hive tables. 
When partition management " + "is enabled, datasource tables store partition in the Hive metastore, and use the " + "metastore to prune partitions during query planning.") .booleanConf diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 31047f688600b..0896caeab8d7a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val plan = testRelation2.select('c).orderBy(Floor('a).asc) val expected = testRelation2.select(c, a) - .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c) + .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 1555dd1cf58d4..8ed7a82b943b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -219,6 +219,12 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Tan, DoubleType) } + test("cot") { + def f: (Double) => Double = (x: Double) => 1 / math.tan(x) + testUnary(Cot, f) + checkConsistencyBetweenInterpretedAndCodegen(Cot, DoubleType) + } + test("atan") { testUnary(Atan, math.atan) checkConsistencyBetweenInterpretedAndCodegen(Atan, DoubleType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index cca0291b3d5af..d78741d032f38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -176,14 +176,14 @@ class PlanParserSuite extends PlanTest { def insert( partition: Map[String, Option[String]], overwrite: Boolean = false, - ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + ifPartitionNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", insert(Map.empty, overwrite = true)) assertEqual(s"insert overwrite table s partition (e = 1) if not exists $sql", - insert(Map("e" -> Option("1")), overwrite = true, ifNotExists = true)) + insert(Map("e" -> Option("1")), overwrite = true, ifPartitionNotExists = true)) assertEqual(s"insert into s $sql", insert(Map.empty)) assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", @@ -193,9 +193,9 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union( InsertIntoTable( - table("u"), Map.empty, plan2, false, ifNotExists = false))) + table("u"), Map.empty, plan2, false, ifPartitionNotExists = false))) } test ("insert with if not exists") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1732a8e08b73f..b71c5eb843eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala 
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -286,7 +286,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partition = Map.empty[String, Option[String]], query = df.logicalPlan, overwrite = mode == SaveMode.Overwrite, - ifNotExists = false) + ifPartitionNotExists = false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 375df64d39734..17671ea8685b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -111,93 +111,60 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** * @since 1.6.1 - * @deprecated use [[newIntSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newLongSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newDoubleSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newFloatSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newByteSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newShortSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newBooleanSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() /** * @since 
1.6.1 - * @deprecated use [[newStringSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() /** * @since 1.6.1 - * @deprecated use [[newProductSequenceEncoder]] + * @deprecated use [[newSequenceEncoder]] */ - implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() + def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() /** @since 2.2.0 */ - implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = - ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] = - ExpressionEncoder() + implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() // Arrays diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 470307bd940ad..bc7e73ae1ba87 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -53,219 +53,288 @@ private[columnar] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: InternalRow, ordinal: Int): Unit = { - if (row.isNullAt(ordinal)) { - nullCount += 1 - // 4 bytes for null position - sizeInBytes += 4 - } + def gatherStats(row: InternalRow, ordinal: Int): Unit + + /** + * Gathers statistics information on `null`. + */ + def gatherNullStats(): Unit = { + nullCount += 1 + // 4 bytes for null position + sizeInBytes += 4 count += 1 } /** - * Column statistics represented as a single row, currently including closed lower bound, closed + * Column statistics represented as an array, currently including closed lower bound, closed * upper bound and null count. */ - def collectedStatistics: GenericInternalRow + def collectedStatistics: Array[Any] } /** * A no-op ColumnStats only used for testing purposes. 
*/ -private[columnar] class NoopColumnStats extends ColumnStats { - override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal) +private[columnar] final class NoopColumnStats extends ColumnStats { + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + if (!row.isNullAt(ordinal)) { + count += 1 + } else { + gatherNullStats + } + } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L)) + override def collectedStatistics: Array[Any] = Array[Any](null, null, nullCount, count, 0L) } -private[columnar] class BooleanColumnStats extends ColumnStats { +private[columnar] final class BooleanColumnStats extends ColumnStats { protected var upper = false protected var lower = true override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getBoolean(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += BOOLEAN.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Boolean): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BOOLEAN.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ByteColumnStats extends ColumnStats { +private[columnar] final class ByteColumnStats extends ColumnStats { protected var upper = Byte.MinValue protected var lower = Byte.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getByte(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - 
sizeInBytes += BYTE.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Byte): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += BYTE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ShortColumnStats extends ColumnStats { +private[columnar] final class ShortColumnStats extends ColumnStats { protected var upper = Short.MinValue protected var lower = Short.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getShort(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += SHORT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Short): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += SHORT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class IntColumnStats extends ColumnStats { +private[columnar] final class IntColumnStats extends ColumnStats { protected var upper = Int.MinValue protected var lower = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getInt(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += INT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override 
def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Int): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += INT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class LongColumnStats extends ColumnStats { +private[columnar] final class LongColumnStats extends ColumnStats { protected var upper = Long.MinValue protected var lower = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getLong(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Long): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += LONG.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class FloatColumnStats extends ColumnStats { +private[columnar] final class FloatColumnStats extends ColumnStats { protected var upper = Float.MinValue protected var lower = Float.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getFloat(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += FLOAT.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, 
nullCount, count, sizeInBytes)) + def gatherValueStats(value: Float): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += FLOAT.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class DoubleColumnStats extends ColumnStats { +private[columnar] final class DoubleColumnStats extends ColumnStats { protected var upper = Double.MinValue protected var lower = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDouble(ordinal) - if (value > upper) upper = value - if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Double): Unit = { + if (value > upper) upper = value + if (value < lower) lower = value + sizeInBytes += DOUBLE.defaultSize + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class StringColumnStats extends ColumnStats { +private[columnar] final class StringColumnStats extends ColumnStats { protected var upper: UTF8String = null protected var lower: UTF8String = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getUTF8String(ordinal) - if (upper == null || value.compareTo(upper) > 0) upper = value.clone() - if (lower == null || value.compareTo(lower) < 0) lower = value.clone() - sizeInBytes += STRING.actualSize(row, ordinal) + val size = STRING.actualSize(row, ordinal) + gatherValueStats(value, size) + } else { + gatherNullStats } } - override def 
collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: UTF8String, size: Int): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value.clone() + if (lower == null || value.compareTo(lower) < 0) lower = value.clone() + sizeInBytes += size + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class BinaryColumnStats extends ColumnStats { +private[columnar] final class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += BINARY.actualSize(row, ordinal) + val size = BINARY.actualSize(row, ordinal) + sizeInBytes += size + count += 1 + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) } -private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { +private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats { def this(dt: DecimalType) = this(dt.precision, dt.scale) protected var upper: Decimal = null protected var lower: Decimal = null override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { val value = row.getDecimal(ordinal, precision, scale) - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value // TODO: this is not right for DecimalType with precision > 18 - sizeInBytes += 8 + gatherValueStats(value) + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new 
GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes)) + def gatherValueStats(value: Decimal): Unit = { + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += 8 + count += 1 + } + + override def collectedStatistics: Array[Any] = + Array[Any](lower, upper, nullCount, count, sizeInBytes) } -private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats { +private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats { val columnType = ColumnType(dataType) override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - sizeInBytes += columnType.actualSize(row, ordinal) + val size = columnType.actualSize(row, ordinal) + sizeInBytes += size + count += 1 + } else { + gatherNullStats } } - override def collectedStatistics: GenericInternalRow = - new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes)) + override def collectedStatistics: Array[Any] = + Array[Any](null, null, nullCount, count, sizeInBytes) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 0a9f3e799990f..3486a6bce8180 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -123,8 +123,8 @@ case class InMemoryRelation( batchStats.add(totalSize) - val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) - .flatMap(_.values)) + val stats = InternalRow.fromSeq( + columnBuilders.flatMap(_.columnStats.collectedStatistics)) CachedBatch(rowCount, columnBuilders.map { builder => JavaUtils.bufferToArray(builder.build()) }, stats) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index bb7d1f70b62d9..14c40605ea31c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -430,6 +430,7 @@ case class DataSource( InsertIntoHadoopFsRelationCommand( outputPath = outputPath, staticPartitions = Map.empty, + ifPartitionNotExists = false, partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d307122b5c70d..21d75a404911b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -142,8 +142,8 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast parts, query, overwrite, false) if parts.isEmpty => InsertIntoDataSourceCommand(l, query, overwrite) - case InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false) => + case i @ InsertIntoTable( + l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) => // If the InsertIntoTable command is for a partitioned HadoopFsRelation and // the user has specified static partitions, we add a Project operator on top of the query // to include those constant column values in the query result. 
@@ -195,6 +195,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast InsertIntoHadoopFsRelationCommand( outputPath, staticPartitions, + i.ifPartitionNotExists, partitionSchema, t.bucketSpec, t.fileFormat, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 19b51d4d9530a..c9d31449d3629 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -37,10 +37,13 @@ import org.apache.spark.sql.execution.command._ * overwrites: when the spec is empty, all partitions are overwritten. * When it covers a prefix of the partition keys, only partitions matching * the prefix are overwritten. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. 
*/ case class InsertIntoHadoopFsRelationCommand( outputPath: Path, staticPartitions: TablePartitionSpec, + ifPartitionNotExists: Boolean, partitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], fileFormat: FileFormat, @@ -61,8 +64,8 @@ case class InsertIntoHadoopFsRelationCommand( val duplicateColumns = query.schema.fieldNames.groupBy(identity).collect { case (x, ys) if ys.length > 1 => "\"" + x + "\"" }.mkString(", ") - throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + - s"cannot save to file.") + throw new AnalysisException(s"Duplicate column(s): $duplicateColumns found, " + + "cannot save to file.") } val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options) @@ -76,11 +79,12 @@ case class InsertIntoHadoopFsRelationCommand( var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty + var matchingPartitions: Seq[CatalogTablePartition] = Seq.empty // When partitions are tracked by the catalog, compute all custom partition locations that // may be relevant to the insertion job. 
if (partitionsTrackedByCatalog) { - val matchingPartitions = sparkSession.sessionState.catalog.listPartitions( + matchingPartitions = sparkSession.sessionState.catalog.listPartitions( catalogTable.get.identifier, Some(staticPartitions)) initialMatchingPartitions = matchingPartitions.map(_.spec) customPartitionLocations = getCustomPartitionLocations( @@ -101,8 +105,12 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.ErrorIfExists, true) => throw new AnalysisException(s"path $qualifiedOutputPath already exists.") case (SaveMode.Overwrite, true) => - deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) - true + if (ifPartitionNotExists && matchingPartitions.nonEmpty) { + false + } else { + deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) + true + } case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true case (SaveMode.Ignore, exists) => diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 1920a108c6584..f7167472b05c6 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -59,3 +59,17 @@ select cot(1); select cot(null); select cot(0); select cot(-1); + +-- ceil and ceiling +select ceiling(0); +select ceiling(1); +select ceil(1234567890123456); +select ceil(12345678901234567); +select ceiling(1234567890123456); +select ceiling(12345678901234567); + +-- floor +select floor(0); +select floor(1); +select floor(1234567890123456); +select floor(12345678901234567); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index abd18211c70d8..fe52005aa91da 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out 
@@ -316,3 +316,83 @@ select cot(-1) struct -- !query 37 output -0.6420926159343306 + + +-- !query 38 +select ceiling(0) +-- !query 38 schema +struct +-- !query 38 output +0 + + +-- !query 39 +select ceiling(1) +-- !query 39 schema +struct +-- !query 39 output +1 + + +-- !query 40 +select ceil(1234567890123456) +-- !query 40 schema +struct +-- !query 40 output +1234567890123456 + + +-- !query 41 +select ceil(12345678901234567) +-- !query 41 schema +struct +-- !query 41 output +12345678901234567 + + +-- !query 42 +select ceiling(1234567890123456) +-- !query 42 schema +struct +-- !query 42 output +1234567890123456 + + +-- !query 43 +select ceiling(12345678901234567) +-- !query 43 schema +struct +-- !query 43 output +12345678901234567 + + +-- !query 44 +select floor(0) +-- !query 44 schema +struct +-- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output +1234567890123456 + + +-- !query 47 +select floor(12345678901234567) +-- !query 47 schema +struct +-- !query 47 output +12345678901234567 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 541565344f758..7e2949ab5aece 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -258,6 +258,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))) } + test("nested sequences") { + checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) + checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala index b2d04f7c5a6e3..d4e7e362c6c8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala @@ -18,33 +18,29 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN, createRow(true, false, 0)) - testColumnStats(classOf[ByteColumnStats], BYTE, createRow(Byte.MaxValue, Byte.MinValue, 0)) - testColumnStats(classOf[ShortColumnStats], SHORT, createRow(Short.MaxValue, Short.MinValue, 0)) - testColumnStats(classOf[IntColumnStats], INT, createRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[LongColumnStats], LONG, createRow(Long.MaxValue, Long.MinValue, 0)) - testColumnStats(classOf[FloatColumnStats], FLOAT, createRow(Float.MaxValue, Float.MinValue, 0)) - testColumnStats(classOf[DoubleColumnStats], DOUBLE, - createRow(Double.MaxValue, Double.MinValue, 0)) - testColumnStats(classOf[StringColumnStats], STRING, createRow(null, null, 0)) - testDecimalColumnStats(createRow(null, null, 0)) - - def createRow(values: Any*): GenericInternalRow = new GenericInternalRow(values.toArray) + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0)) + testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0)) + 
testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0)) + testDecimalColumnStats(Array(null, null, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], columnType: NativeColumnType[T], - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -60,11 +56,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum @@ -73,14 +69,14 @@ class ColumnStatsSuite extends SparkFunSuite { } def testDecimalColumnStats[T <: AtomicType, U <: ColumnStats]( - initialStatistics: GenericInternalRow): Unit = { + initialStatistics: Array[Any]): Unit = { val columnStatsName = 
classOf[DecimalColumnStats].getSimpleName val columnType = COMPACT_DECIMAL(15, 10) test(s"$columnStatsName: empty") { val columnStats = new DecimalColumnStats(15, 10) - columnStats.collectedStatistics.values.zip(initialStatistics.values).foreach { + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => assert(actual === expected) } } @@ -96,11 +92,11 @@ class ColumnStatsSuite extends SparkFunSuite { val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#InternalType]] val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(stats.values(0)) - assertResult(values.max(ordering), "Wrong upper bound")(stats.values(1)) - assertResult(10, "Wrong null count")(stats.values(2)) - assertResult(20, "Wrong row count")(stats.values(3)) - assertResult(stats.values(4), "Wrong size in bytes") { + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) + assertResult(20, "Wrong row count")(stats(3)) + assertResult(stats(4), "Wrong size in bytes") { rows.map { row => if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0) }.sum diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 09a5eda6e543f..4f090d545cd18 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -160,9 +160,9 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { */ object HiveAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case InsertIntoTable(relation: CatalogRelation, partSpec, query, overwrite, ifNotExists) - if DDLUtils.isHiveTable(relation.tableMeta) => - 
InsertIntoHiveTable(relation.tableMeta, partSpec, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partSpec, query, overwrite, ifPartitionNotExists) + if DDLUtils.isHiveTable(r.tableMeta) => + InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore) @@ -207,11 +207,11 @@ case class RelationConversions( override def apply(plan: LogicalPlan): LogicalPlan = { plan transformUp { // Write path - case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifNotExists) + case InsertIntoTable(r: CatalogRelation, partition, query, overwrite, ifPartitionNotExists) // Inserting into partitioned table is not supported in Parquet/Orc data source (yet). - if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && - !r.isPartitioned && isConvertible(r) => - InsertIntoTable(convert(r), partition, query, overwrite, ifNotExists) + if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && + !r.isPartitioned && isConvertible(r) => + InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists) // Read path case relation: CatalogRelation diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 41c6b18e9d794..65e8b4e3c725c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -62,7 +62,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = false, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in 
case of failure occurs while data @@ -78,7 +78,7 @@ case class CreateHiveTableAsSelectCommand( Map(), query, overwrite = true, - ifNotExists = false)).toRdd + ifPartitionNotExists = false)).toRdd } catch { case NonFatal(e) => // drop the created table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 10e17c5f73433..10ce8e3730a0d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -71,14 +71,15 @@ import org.apache.spark.SparkException * }}}. * @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param ifNotExists If true, only write if the table or partition does not exist. + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. */ case class InsertIntoHiveTable( table: CatalogTable, partition: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) extends RunnableCommand { + ifPartitionNotExists: Boolean) extends RunnableCommand { override protected def innerChildren: Seq[LogicalPlan] = query :: Nil @@ -375,7 +376,7 @@ case class InsertIntoHiveTable( var doHiveOverwrite = overwrite - if (oldPart.isEmpty || !ifNotExists) { + if (oldPart.isEmpty || !ifPartitionNotExists) { // SPARK-18107: Insert overwrite runs much slower than hive-client. // Newer Hive largely improves insert overwrite performance. As Spark uses older Hive // version and we may not want to catch up new Hive version every time. 
We delete the diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 7bd3973550043..cc80f2e481cbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -166,72 +166,54 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE tmp_table") } - test("INSERT OVERWRITE - partition IF NOT EXISTS") { - withTempDir { tmpDir => - val table = "table_with_partition" - withTable(table) { - val selQuery = s"select c1, p1, p2 from $table" - sql( - s""" - |CREATE TABLE $table(c1 string) - |PARTITIONED by (p1 string,p2 string) - |location '${tmpDir.toURI.toString}' - """.stripMargin) - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr", "a", "b")) - - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') - |SELECT 'blarr2' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) + testPartitionedTable("INSERT OVERWRITE - partition IF NOT EXISTS") { tableName => + val selQuery = s"select a, b, c, d from $tableName" + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 1, 4 + """.stripMargin) + checkAnswer(sql(selQuery), Row(1, 2, 3, 4)) - var e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'newPartition' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. 
Specified partitions with value: [p2]")) + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) + |SELECT 5, 6 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) + + val e = intercept[AnalysisException] { + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c) IF NOT EXISTS + |SELECT 7, 8, 3 + """.stripMargin) + } + assert(e.getMessage.contains( + "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]")) - e = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2) IF NOT EXISTS - |SELECT 'blarr3', 'b' - """.stripMargin) - } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [p2]")) + // If the partition already exists, the insert will overwrite the data + // unless users specify IF NOT EXISTS + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=2, c=3) IF NOT EXISTS + |SELECT 9, 10 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) - // If the partition already exists, the insert will overwrite the data - // unless users specify IF NOT EXISTS - sql( - s""" - |INSERT OVERWRITE TABLE $table - |partition (p1='a',p2='b') IF NOT EXISTS - |SELECT 'blarr3' - """.stripMargin) - checkAnswer( - sql(selQuery), - Row("blarr2", "a", "b")) - } - } + // ADD PARTITION has the same effect, even if no actual data is inserted. + sql(s"ALTER TABLE $tableName ADD PARTITION (b=21, c=31)") + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |partition (b=21, c=31) IF NOT EXISTS + |SELECT 20, 24 + """.stripMargin) + checkAnswer(sql(selQuery), Row(5, 2, 3, 6)) } test("Insert ArrayType.containsNull == false") {