diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 82d2428f3c444..15af8298ba484 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -69,6 +69,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features #' or the number of partitions are large, this param could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. 
#' @rdname spark.svmLinear @@ -98,7 +103,8 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @note spark.svmLinear since 2.2.0 setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE, - threshold = 0.0, weightCol = NULL, aggregationDepth = 2) { + threshold = 0.0, weightCol = NULL, aggregationDepth = 2, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (!is.null(weightCol) && weightCol == "") { @@ -107,10 +113,12 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu weightCol <- as.character(weightCol) } + handleInvalid <- match.arg(handleInvalid) + jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.integer(maxIter), as.numeric(tol), as.logical(standardization), as.numeric(threshold), - weightCol, as.integer(aggregationDepth)) + weightCol, as.integer(aggregationDepth), handleInvalid) new("LinearSVCModel", jobj = jobj) }) @@ -218,6 +226,11 @@ function(object, path, overwrite = FALSE) { #' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. #' The bound vector size must be equal to 1 for binomial regression, or the number #' of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. 
#' @rdname spark.logit @@ -257,7 +270,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") tol = 1E-6, family = "auto", standardization = TRUE, thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, lowerBoundsOnCoefficients = NULL, upperBoundsOnCoefficients = NULL, - lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL) { + lowerBoundsOnIntercepts = NULL, upperBoundsOnIntercepts = NULL, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") row <- 0 col <- 0 @@ -304,6 +318,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") upperBoundsOnCoefficients <- as.array(as.vector(upperBoundsOnCoefficients)) } + handleInvalid <- match.arg(handleInvalid) + jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", data@sdf, formula, as.numeric(regParam), as.numeric(elasticNetParam), as.integer(maxIter), @@ -312,7 +328,8 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") weightCol, as.integer(aggregationDepth), as.integer(row), as.integer(col), lowerBoundsOnCoefficients, upperBoundsOnCoefficients, - lowerBoundsOnIntercepts, upperBoundsOnIntercepts) + lowerBoundsOnIntercepts, upperBoundsOnIntercepts, + handleInvalid) new("LogisticRegressionModel", jobj = jobj) }) @@ -394,7 +411,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param stepSize stepSize parameter. #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a -#' numeric vector. +#' numeric vector. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). 
Default is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp @@ -426,7 +448,8 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { + tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") @@ -441,10 +464,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), if (!is.null(initialWeights)) { initialWeights <- as.array(as.numeric(na.omit(initialWeights))) } + handleInvalid <- match.arg(handleInvalid) jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed, initialWeights) + as.numeric(stepSize), seed, initialWeights, handleInvalid) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) @@ -514,6 +538,11 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type. 
+#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes @@ -543,10 +572,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' } #' @note spark.naiveBayes since 2.0.0 setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, smoothing = 1.0) { + function(data, formula, smoothing = 1.0, + handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") + handleInvalid <- match.arg(handleInvalid) jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", - formula, data@sdf, smoothing) + formula, data@sdf, smoothing, handleInvalid) new("NaiveBayesModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 75b1a74ee8c7c..33c4653f4c184 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -164,6 +164,11 @@ print.summary.decisionTree <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. 
#' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -205,7 +210,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, - checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -225,6 +231,7 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), new("GBTRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(lossType)) lossType <- "logistic" lossType <- match.arg(lossType, "logistic") jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", @@ -233,7 +240,8 @@ setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), as.numeric(stepSize), as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), lossType, seed, as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("GBTClassificationModel", jobj = jobj) } ) @@ -374,10 +382,11 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in classification model. 
-#' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -583,6 +592,11 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label +#' column of string type in classification model. +#' Supported options: "skip" (filter out rows with invalid data), +#' "error" (throw an error), "keep" (put invalid data in a special additional +#' bucket, at index numLabels). Default is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.decisionTree,SparkDataFrame,formula-method #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. 
@@ -617,7 +631,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, impurity = NULL, seed = NULL, minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - maxMemoryInMB = 256, cacheNodeIds = FALSE) { + maxMemoryInMB = 256, cacheNodeIds = FALSE, + handleInvalid = c("error", "keep", "skip")) { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -636,6 +651,7 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo new("DecisionTreeRegressionModel", jobj = jobj) }, classification = { + handleInvalid <- match.arg(handleInvalid) if (is.null(impurity)) impurity <- "gini" impurity <- match.arg(impurity, c("gini", "entropy")) jobj <- callJStatic("org.apache.spark.ml.r.DecisionTreeClassifierWrapper", @@ -643,7 +659,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo as.integer(maxBins), impurity, as.integer(minInstancesPerNode), as.numeric(minInfoGain), as.integer(checkpointInterval), seed, - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + as.integer(maxMemoryInMB), as.logical(cacheNodeIds), + handleInvalid) new("DecisionTreeClassificationModel", jobj = jobj) } ) diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R b/R/pkg/tests/fulltests/test_mllib_classification.R index 3d75f4ce11ec8..a4d0397236d17 100644 --- a/R/pkg/tests/fulltests/test_mllib_classification.R +++ b/R/pkg/tests/fulltests/test_mllib_classification.R @@ -70,6 +70,20 @@ test_that("spark.svmLinear", { prediction <- collect(select(predict(model, df), "prediction")) expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- 
base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.svmLinear(traindf, clicked ~ ., regParam = 0.1, handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "list") + }) test_that("spark.logit", { @@ -263,6 +277,21 @@ test_that("spark.logit", { virginicaCoefs <- summary$coefficients[, "virginica"] expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.logit(traindf, clicked ~ ., regParam = 0.5) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.logit(traindf, clicked ~ ., regParam = 0.5, handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") + }) test_that("spark.mlp", { @@ -344,6 +373,21 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], 
c(0, "the other"))) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3)) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.mlp(traindf, clicked ~ ., layers = c(1, 3), handleInvalid = "skip") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "list") + }) test_that("spark.naiveBayes", { @@ -427,6 +471,20 @@ test_that("spark.naiveBayes", { expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) expect_equal(sum(s$apriori), 1) expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.naiveBayes(traindf, clicked ~ ., smoothing = 0.0, handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") }) sparkR.session.stop() diff --git a/R/pkg/tests/fulltests/test_mllib_tree.R b/R/pkg/tests/fulltests/test_mllib_tree.R index e31a65f8dfedb..799f94401d008 100644 --- a/R/pkg/tests/fulltests/test_mllib_tree.R +++ b/R/pkg/tests/fulltests/test_mllib_tree.R @@ -109,6 +109,20 @@ test_that("spark.gbt", { model <- spark.gbt(data, label ~ features, "classification") expect_equal(summary(model)$numFeatures, 692) } + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + 
traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.gbt(traindf, clicked ~ ., type = "classification") + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.gbt(traindf, clicked ~ ., type = "classification", handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") }) test_that("spark.randomForest", { @@ -328,6 +342,22 @@ test_that("spark.decisionTree", { model <- spark.decisionTree(data, label ~ features, "classification") expect_equal(summary(model)$numFeatures, 4) } + + # Test unseen labels + data <- data.frame(clicked = base::sample(c(0, 1), 10, replace = TRUE), + someString = base::sample(c("this", "that"), 10, replace = TRUE), + stringsAsFactors = FALSE) + trainidxs <- base::sample(nrow(data), nrow(data) * 0.7) + traindf <- as.DataFrame(data[trainidxs, ]) + testdf <- as.DataFrame(rbind(data[-trainidxs, ], c(0, "the other"))) + model <- spark.decisionTree(traindf, clicked ~ ., type = "classification", + maxDepth = 5, maxBins = 16) + predictions <- predict(model, testdf) + expect_error(collect(predictions)) + model <- spark.decisionTree(traindf, clicked ~ ., type = "classification", + maxDepth = 5, maxBins = 16, handleInvalid = "keep") + predictions <- predict(model, testdf) + expect_equal(class(collect(predictions)$clicked[1]), "character") }) sparkR.session.stop() diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..ea9b3ce4e3522 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -18,12 +18,13 @@ package 
org.apache.spark.network.buffer; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -93,9 +94,9 @@ public ByteBuffer nioByteBuffer() throws IOException { @Override public InputStream createInputStream() throws IOException { - FileInputStream is = null; + InputStream is = null; try { - is = new FileInputStream(file); + is = Files.newInputStream(file.toPath()); ByteStreams.skipFully(is, offset); return new LimitedInputStream(is, length); } catch (IOException e) { @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 2f160d12af22b..66b67e282c80d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -18,11 +18,11 @@ package org.apache.spark.network.shuffle; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.WritableByteChannel; +import java.nio.file.Files; import java.util.Arrays; import org.slf4j.Logger; @@ -165,7 +165,7 @@ private class 
DownloadCallback implements StreamCallback { DownloadCallback(int chunkIndex) throws IOException { this.targetFile = tempShuffleFileManager.createTempShuffleFile(); - this.channel = Channels.newChannel(new FileOutputStream(targetFile)); + this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); this.chunkIndex = chunkIndex; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index ec57f0259d55c..39ca9ba574853 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -38,7 +38,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 323a5d3c52831..a9b5236ab8173 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -18,9 +18,9 @@ package org.apache.spark.shuffle.sort; import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; import 
java.io.IOException; +import java.nio.channels.FileChannel; +import static java.nio.file.StandardOpenOption.*; import javax.annotation.Nullable; import scala.None$; @@ -75,7 +75,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class); private final int fileBufferSize; - private final boolean transferToEnabled; private final int numPartitions; private final BlockManager blockManager; private final Partitioner partitioner; @@ -107,7 +106,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { SparkConf conf) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; - this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; @@ -188,17 +186,21 @@ private long[] writePartitionedFile(File outputFile) throws IOException { return lengths; } - final FileOutputStream out = new FileOutputStream(outputFile, true); + // This file needs to be opened in append mode in order to work around a Linux kernel bug that + // affects transferTo; see SPARK-3948 for more details. 
+ final FileChannel out = FileChannel.open(outputFile.toPath(), WRITE, APPEND, CREATE); final long writeStartTime = System.nanoTime(); boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); if (file.exists()) { - final FileInputStream in = new FileInputStream(file); + final FileChannel in = FileChannel.open(file.toPath(), READ); boolean copyThrewException = true; try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + long size = in.size(); + Utils.copyFileStreamNIO(in, out, 0, size); + lengths[i] = size; copyThrewException = false; } finally { Closeables.close(in, copyThrewException); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 1b578491b81d7..c0ebe3cc9b792 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,6 +20,7 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import static java.nio.file.StandardOpenOption.*; import java.util.Iterator; import scala.Option; @@ -290,7 +291,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { - new FileOutputStream(outputFile).close(); // Create an empty file + java.nio.file.Files.newOutputStream(outputFile.toPath()).close(); // Create an empty file return new long[partitioner.numPartitions()]; } else if (spills.length == 1) { // Here, we don't need to perform any metrics updates because the bytes written to this @@ -367,7 +368,7 @@ private long[] mergeSpillsWithFileStream( final InputStream[] spillInputStreams = new InputStream[spills.length]; final OutputStream bos = new BufferedOutputStream( - 
new FileOutputStream(outputFile), + java.nio.file.Files.newOutputStream(outputFile.toPath()), outputBufferSizeInBytes); // Use a counting output stream to avoid having to close the underlying file and ask // the file system for its size after each partition is written. @@ -442,11 +443,11 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th boolean threwException = true; try { for (int i = 0; i < spills.length; i++) { - spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel(); + spillInputChannels[i] = FileChannel.open(spills[i].file.toPath(), READ); } // This file needs to opened in append mode in order to work around a Linux kernel bug that // affects transferTo; see SPARK-3948 for more details. - mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel(); + mergedFileOutputChannel = FileChannel.open(outputFile.toPath(), WRITE, CREATE, APPEND); long bytesWrittenToMergedFile = 0; for (int partition = 0; partition < numPartitions; partition++) { diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 3b6200e74f1e1..610ace30f8a62 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -258,6 +258,11 @@ private MapIterator(int numRecords, Location loc, boolean destructive) { this.destructive = destructive; if (destructive) { destructiveIterator = this; + // longArray will not be used anymore if destructive is true, release it now. 
+ if (longArray != null) { + freeArray(longArray); + longArray = null; + } } } diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala index 3432700f11602..fe7438ac54f18 100644 --- a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -37,13 +37,7 @@ private[r] class JVMObjectTracker { /** * Returns the JVM object associated with the input key or None if not found. */ - final def get(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.get(id)) - } else { - None - } - } + final def get(id: JVMObjectId): Option[Object] = Option(objMap.get(id)) /** * Returns the JVM object associated with the input key or throws an exception if not found. @@ -67,13 +61,7 @@ private[r] class JVMObjectTracker { /** * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. */ - final def remove(id: JVMObjectId): Option[Object] = this.synchronized { - if (objMap.containsKey(id)) { - Some(objMap.remove(id)) - } else { - None - } - } + final def remove(id: JVMObjectId): Option[Object] = Option(objMap.remove(id)) /** * Number of JVM objects being tracked. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2985c90119468..5435f59ea0d28 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -55,7 +55,7 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * Doubles; and * [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that * can be saved as SequenceFiles. - * All operations are automatically available on any RDD of the right type (e.g. 
RDD[(Int, Int)] + * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]) * through implicit. * * Internally, each RDD is characterized by five main properties: diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 15540485170d0..94a3a78e94165 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import java.io._ +import java.nio.file.Files import com.google.common.io.ByteStreams @@ -141,7 +142,8 @@ private[spark] class IndexShuffleBlockResolver( val indexFile = getIndexFile(shuffleId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp))) + val out = new DataOutputStream( + new BufferedOutputStream(Files.newOutputStream(indexTmp.toPath))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. 
var offset = 0L @@ -196,7 +198,7 @@ private[spark] class IndexShuffleBlockResolver( // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) - val in = new DataInputStream(new FileInputStream(indexFile)) + val in = new DataInputStream(Files.newInputStream(indexFile.toPath)) try { ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8aafda5e45d52..a08563562b874 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -18,6 +18,8 @@ package org.apache.spark.util.collection import java.io._ +import java.nio.channels.{Channels, FileChannel} +import java.nio.file.StandardOpenOption import java.util.Comparator import scala.collection.BufferedIterator @@ -460,7 +462,7 @@ class ExternalAppendOnlyMap[K, V, C]( ) private var batchIndex = 0 // Which batch we're in - private var fileStream: FileInputStream = null + private var fileChannel: FileChannel = null // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams @@ -477,14 +479,14 @@ class ExternalAppendOnlyMap[K, V, C]( if (batchIndex < batchOffsets.length - 1) { if (deserializeStream != null) { deserializeStream.close() - fileStream.close() + fileChannel.close() deserializeStream = null - fileStream = null + fileChannel = null } val start = batchOffsets(batchIndex) - fileStream = new FileInputStream(file) - fileStream.getChannel.position(start) + fileChannel = FileChannel.open(file.toPath, StandardOpenOption.READ) + fileChannel.position(start) batchIndex += 1 val end = batchOffsets(batchIndex) @@ -492,7 +494,8 @@ class 
ExternalAppendOnlyMap[K, V, C]( assert(end >= start, "start = " + start + ", end = " + end + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val bufferedStream = new BufferedInputStream( + ByteStreams.limit(Channels.newInputStream(fileChannel), end - start)) val wrappedStream = serializerManager.wrapStream(blockId, bufferedStream) ser.deserializeStream(wrappedStream) } else { @@ -552,9 +555,9 @@ class ExternalAppendOnlyMap[K, V, C]( ds.close() deserializeStream = null } - if (fileStream != null) { - fileStream.close() - fileStream = null + if (fileChannel != null) { + fileChannel.close() + fileChannel = null } if (file.exists()) { if (!file.delete()) { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 176f84fa2a0d2..3593cfd507783 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,6 +18,8 @@ package org.apache.spark.util.collection import java.io._ +import java.nio.channels.{Channels, FileChannel} +import java.nio.file.StandardOpenOption import java.util.Comparator import scala.collection.mutable @@ -492,7 +494,7 @@ private[spark] class ExternalSorter[K, V, C]( // Intermediate file and deserializer streams that read from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams - var fileStream: FileInputStream = null + var fileChannel: FileChannel = null var deserializeStream = nextBatchStream() // Also sets fileStream var nextItem: (K, C) = null @@ -505,14 +507,14 @@ private[spark] class ExternalSorter[K, V, C]( if (batchId < batchOffsets.length - 1) { if (deserializeStream != null) { deserializeStream.close() - fileStream.close() + fileChannel.close() deserializeStream = 
null - fileStream = null + fileChannel = null } val start = batchOffsets(batchId) - fileStream = new FileInputStream(spill.file) - fileStream.getChannel.position(start) + fileChannel = FileChannel.open(spill.file.toPath, StandardOpenOption.READ) + fileChannel.position(start) batchId += 1 val end = batchOffsets(batchId) @@ -520,7 +522,8 @@ private[spark] class ExternalSorter[K, V, C]( assert(end >= start, "start = " + start + ", end = " + end + ", batchOffsets = " + batchOffsets.mkString("[", ", ", "]")) - val bufferedStream = new BufferedInputStream(ByteStreams.limit(fileStream, end - start)) + val bufferedStream = new BufferedInputStream( + ByteStreams.limit(Channels.newInputStream(fileChannel), end - start)) val wrappedStream = serializerManager.wrapStream(spill.blockId, bufferedStream) serInstance.deserializeStream(wrappedStream) @@ -610,7 +613,7 @@ private[spark] class ExternalSorter[K, V, C]( batchId = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream deserializeStream = null - fileStream = null + fileChannel = null if (ds != null) { ds.close() } diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index 4bacb385184c6..28971b87f403f 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -17,10 +17,11 @@ # limitations under the License. # -# Utility for creating well-formed pull request merges and pushing them to Apache. -# usage: ./apache-pr-merge.py (see config env vars below) +# Utility for creating well-formed pull request merges and pushing them to Apache +# Spark. +# usage: ./merge_spark_pr.py (see config env vars below) # -# This utility assumes you already have local a Spark git folder and that you +# This utility assumes you already have a local Spark git folder and that you # have added remotes corresponding to both (i) the github apache Spark # mirror and (ii) the apache git repo. 
diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md index c0215c8fb62f6..26025984da64c 100644 --- a/docs/rdd-programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -978,40 +978,40 @@ for details. Return a new RDD that contains the intersection of elements in the source dataset and the argument. - distinct([numTasks])) + distinct([numPartitions])) Return a new dataset that contains the distinct elements of the source dataset. - groupByKey([numTasks]) + groupByKey([numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. - You can pass an optional numTasks argument to set a different number of tasks. + You can pass an optional numPartitions argument to set a different number of tasks. - reduceByKey(func, [numTasks]) + reduceByKey(func, [numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function func, which must be of type (V,V) => V. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - aggregateByKey(zeroValue)(seqOp, combOp, [numTasks]) + aggregateByKey(zeroValue)(seqOp, combOp, [numPartitions]) When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in groupByKey, the number of reduce tasks is configurable through an optional second argument. - sortByKey([ascending], [numTasks]) + sortByKey([ascending], [numPartitions]) When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean ascending argument. - join(otherDataset, [numTasks]) + join(otherDataset, [numPartitions]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (V, W)) pairs with all pairs of elements for each key. Outer joins are supported through leftOuterJoin, rightOuterJoin, and fullOuterJoin. - cogroup(otherDataset, [numTasks]) + cogroup(otherDataset, [numPartitions]) When called on datasets of type (K, V) and (K, W), returns a dataset of (K, (Iterable<V>, Iterable<W>)) tuples. This operation is also called groupWith. 
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b5eca76480eb8..7f7cf59b7a9a8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1903,6 +1903,23 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. +**Hive UDF/UDTF/UDAF** + +Not all the APIs of the Hive UDF/UDTF/UDAF are supported by Spark SQL. Below are the unsupported APIs: + +* `getRequiredJars` and `getRequiredFiles` (`UDF` and `GenericUDF`) are functions to automatically + include additional resources required by this UDF. +* `initialize(StructObjectInspector)` in `GenericUDTF` is not supported yet. Spark SQL currently uses + a deprecated interface `initialize(ObjectInspector[])` only. +* `configure` (`GenericUDF`, `GenericUDTF`, and `GenericUDAFEvaluator`) is a function to initialize + functions with `MapredContext`, which is inapplicable to Spark. +* `close` (`GenericUDF` and `GenericUDAFEvaluator`) is a function to release associated resources. + Spark SQL does not call this function when tasks finish. +* `reset` (`GenericUDAFEvaluator`) is a function to re-initialize aggregation for reusing the same aggregation. + Spark SQL currently does not support the reuse of aggregation. +* `getWindowingEvaluator` (`GenericUDAFEvaluator`) is a function to optimize aggregation by evaluating + an aggregate over a fixed window. 
+ # Reference ## Data Types diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index d4ddcb16bdd0e..44ae52e81cd64 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -175,7 +175,7 @@ an input DStream using data received by the instance of custom receiver, as show {% highlight scala %} // Assuming ssc is the StreamingContext val customReceiverStream = ssc.receiverStream(new CustomReceiver(host, port)) -val words = lines.flatMap(_.split(" ")) +val words = customReceiverStream.flatMap(_.split(" ")) ... {% endhighlight %} @@ -187,7 +187,7 @@ The full source code is in the example [CustomReceiver.scala]({{site.SPARK_GITHU {% highlight java %} // Assuming ssc is the JavaStreamingContext JavaDStream customReceiverStream = ssc.receiverStream(new JavaCustomReceiver(host, port)); -JavaDStream words = lines.flatMap(s -> ...); +JavaDStream words = customReceiverStream.flatMap(s -> ...); ... {% endhighlight %} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index d6ed6a4570a4a..8d556deef2be8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.functions.{col, lit} /** Params for linear SVM Classifier. */ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol - with HasAggregationDepth { + with HasAggregationDepth with HasThreshold { /** * Param for threshold in binary classification prediction. 
@@ -53,11 +53,8 @@ private[classification] trait LinearSVCParams extends ClassifierParams with HasR * * @group param */ - final val threshold: DoubleParam = new DoubleParam(this, "threshold", + final override val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction applied to rawPrediction") - - /** @group getParam */ - def getThreshold: Double = $(threshold) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 6bba7f9b08dbd..21957d94e2dc3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -366,6 +366,7 @@ class LogisticRegression @Since("1.2.0") ( @Since("1.5.0") override def setThreshold(value: Double): this.type = super.setThreshold(value) + setDefault(threshold -> 0.5) @Since("1.5.0") override def getThreshold: Double = super.getThreshold diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index b6909b3386b71..d4c8e4b361959 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.feature import org.apache.hadoop.fs.Path -import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index fd9b20ed9364a..1860fe8361749 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -47,8 +47,8 @@ private[shared] object SharedParamsCodeGen { Some("\"probability\"")), ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", - "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), - isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), + "threshold in binary classification prediction, in range [0, 1]", + isValid = "ParamValidators.inRange(0, 1)", finalMethods = false, finalFields = false), ParamDesc[Array[Double]]("thresholds", "Thresholds in multi-class classification" + " to adjust the probability of predicting each class." + " Array must have length equal to the number of classes, with values > 0" + @@ -77,7 +77,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms (>= 0)", isValid = "ParamValidators.gtEq(0)"), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization (>" + - " 0)", isValid = "ParamValidators.gt(0)"), + " 0)", isValid = "ParamValidators.gt(0)", finalFields = false), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0"), ParamDesc[String]("solver", "the solver algorithm for optimization", finalFields = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a29b45c3ec66c..545e45e84e9ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -162,9 +162,7 @@ private[ml] trait HasThreshold extends Params { * Param for threshold in binary classification prediction, in range [0, 1]. 
* @group param */ - final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) - - setDefault(threshold, 0.5) + val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1)) /** @group getParam */ def getThreshold: Double = $(threshold) @@ -352,7 +350,7 @@ private[ml] trait HasStepSize extends Params { * Param for Step size to be used for each iteration of optimization (> 0). * @group param */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", ParamValidators.gt(0)) + val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization (> 0)", ParamValidators.gt(0)) /** @group getParam */ final def getStepSize: Double = $(stepSize) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala index 7f59825504d8e..a90cae5869b2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/DecisionTreeClassificationWrapper.scala @@ -73,11 +73,13 @@ private[r] object DecisionTreeClassifierWrapper extends MLReadable[DecisionTreeC checkpointInterval: Int, seed: String, maxMemoryInMB: Int, - cacheNodeIds: Boolean): DecisionTreeClassifierWrapper = { + cacheNodeIds: Boolean, + handleInvalid: String): DecisionTreeClassifierWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala 
b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala index c07eadb30a4d2..ecaeec5a7791a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -78,11 +78,13 @@ private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] seed: String, subsamplingRate: Double, maxMemoryInMB: Int, - cacheNodeIds: Boolean): GBTClassifierWrapper = { + cacheNodeIds: Boolean, + handleInvalid: String): GBTClassifierWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala index 0dd1f1146fbf8..7a22a71c3a819 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala @@ -79,12 +79,14 @@ private[r] object LinearSVCWrapper standardization: Boolean, threshold: Double, weightCol: String, - aggregationDepth: Int + aggregationDepth: Int, + handleInvalid: String ): LinearSVCWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index b96481acf46d7..18acf7d21656f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -103,12 +103,14 @@ private[r] object LogisticRegressionWrapper lowerBoundsOnCoefficients: Array[Double], upperBoundsOnCoefficients: Array[Double], 
lowerBoundsOnIntercepts: Array[Double], - upperBoundsOnIntercepts: Array[Double] + upperBoundsOnIntercepts: Array[Double], + handleInvalid: String ): LogisticRegressionWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 48c87743dee60..62f642142701b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -62,7 +62,7 @@ private[r] object MultilayerPerceptronClassifierWrapper val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" val PREDICTED_LABEL_COL = "prediction" - def fit( + def fit( // scalastyle:ignore data: DataFrame, formula: String, blockSize: Int, @@ -72,11 +72,13 @@ private[r] object MultilayerPerceptronClassifierWrapper tol: Double, stepSize: Double, seed: String, - initialWeights: Array[Double] + initialWeights: Array[Double], + handleInvalid: String ): MultilayerPerceptronClassifierWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 0afea4be3d1dd..fbf9f462ff5f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -57,10 +57,15 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val PREDICTED_LABEL_INDEX_COL = 
"pred_label_idx" val PREDICTED_LABEL_COL = "prediction" - def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = { + def fit( + formula: String, + data: DataFrame, + smoothing: Double, + handleInvalid: String): NaiveBayesWrapper = { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) + .setHandleInvalid(handleInvalid) checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 3fc3ac58b7795..47079d9c6bb1c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -458,7 +458,7 @@ private[ml] trait RandomForestRegressorParams * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -484,13 +484,10 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { * (default = 0.1) * @group param */ - final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " + + final override val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " + "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.", ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) - /** @group getParam */ - final def getStepSize: Double = $(stepSize) - /** * @deprecated This method is deprecated and will be removed in 3.0.0. 
* @group setParam diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 4af6f71e19257..ab1617ba47221 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -63,7 +63,7 @@ def numClasses(self): @inherit_doc class LinearSVC(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasRawPredictionCol, HasFitIntercept, HasStandardization, - HasWeightCol, HasAggregationDepth, JavaMLWritable, JavaMLReadable): + HasWeightCol, HasAggregationDepth, HasThreshold, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -153,18 +153,6 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LinearSVCModel(java_model) - def setThreshold(self, value): - """ - Sets the value of :py:attr:`threshold`. - """ - return self._set(threshold=value) - - def getThreshold(self): - """ - Gets the value of threshold or its default value. - """ - return self.getOrDefault(self.threshold) - class LinearSVCModel(JavaModel, JavaClassificationModel, JavaMLWritable, JavaMLReadable): """ @@ -1030,6 +1018,11 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol "Supported options: " + ", ".join(GBTParams.supportedLossTypes), typeConverter=TypeConverters.toString) + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. 
learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator.", + typeConverter=TypeConverters.toFloat) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f0ff7a5f59abf..2cc623427edc8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1014,6 +1014,11 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, "Supported options: " + ", ".join(GBTParams.supportedLossTypes), typeConverter=TypeConverters.toString) + stepSize = Param(Params._dummy(), "stepSize", + "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " + + "the contribution of each estimator.", + typeConverter=TypeConverters.toFloat) + @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 54756edd9345d..cfd9c558ff67e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1241,26 +1241,29 @@ def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) self.assertEqual(struct1, struct2) struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) self.assertNotEqual(struct1, struct2) struct1 = (StructType().add(StructField("f1", StringType(), True)) .add(StructField("f2", StringType(), True, None))) struct2 = 
StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1.fieldNames(), struct2.names) self.assertEqual(struct1, struct2) struct1 = (StructType().add(StructField("f1", StringType(), True)) .add(StructField("f2", StringType(), True, None))) struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1.fieldNames(), struct2.names) self.assertNotEqual(struct1, struct2) # Catch exception raised during improper construction - with self.assertRaises(ValueError): - struct1 = StructType().add("name") + self.assertRaises(ValueError, lambda: StructType().add("name")) struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) for field in struct1: @@ -1273,12 +1276,9 @@ def test_struct_type(self): self.assertIs(struct1["f1"], struct1.fields[0]) self.assertIs(struct1[0], struct1.fields[0]) self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) - with self.assertRaises(KeyError): - not_a_field = struct1["f9"] - with self.assertRaises(IndexError): - not_a_field = struct1[9] - with self.assertRaises(TypeError): - not_a_field = struct1[9.9] + self.assertRaises(KeyError, lambda: struct1["f9"]) + self.assertRaises(IndexError, lambda: struct1[9]) + self.assertRaises(TypeError, lambda: struct1[9.9]) def test_parse_datatype_string(self): from pyspark.sql.types import _all_atomic_types, _parse_datatype_string diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index c376805c32738..ecb8eb9a2f2fa 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -446,9 +446,12 @@ class StructType(DataType): This is the data type representing a :class:`Row`. - Iterating a :class:`StructType` will iterate its :class:`StructField`s. + Iterating a :class:`StructType` will iterate its :class:`StructField`\\s. A contained :class:`StructField` can be accessed by name or position. + .. note:: `names` attribute is deprecated in 2.3. 
Use `fieldNames` method instead + to get a list of field names. + >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct1["f1"] StructField(f1,StringType,true) @@ -563,6 +566,16 @@ def jsonValue(self): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def fieldNames(self): + """ + Returns all field names in a list. + + >>> struct = StructType([StructField("f1", StringType(), True)]) + >>> struct.fieldNames() + ['f1'] + """ + return list(self.names) + def needConversion(self): # We need convert Row()/namedtuple into tuple() return True diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index ca6a3ef3ebbb5..0387b44dbcc10 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -474,6 +474,7 @@ private[spark] class ApplicationMaster( addAmIpFilter() registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"), securityMgr) + registered = true // In client mode the actor will stop the reporter thread. reporterThread.join() diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index ef9f88a9026c9..4534b7dcf6399 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -474,7 +474,7 @@ identifierComment relationPrimary : tableIdentifier sample? tableAlias #tableName - | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery + | '(' queryNoWith ')' sample? tableAlias #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? 
#aliasedRelation | inlineTable #inlineTableDefault2 | functionTable #tableValuedFunction diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 913d846a8c23b..a6d297cfd6538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -141,6 +141,7 @@ class Analyzer( ResolveFunctions :: ResolveAliases :: ResolveSubquery :: + ResolveSubqueryColumnAliases :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: @@ -1323,6 +1324,30 @@ class Analyzer( } } + /** + * Replaces unresolved column aliases for a subquery with projections. + */ + object ResolveSubqueryColumnAliases extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case u @ UnresolvedSubqueryColumnAliases(columnNames, child) if child.resolved => + // Resolves output attributes if a query has alias names in its subquery: + // e.g., SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) + val outputAttrs = child.output + // Checks if the number of the aliases equals to the number of output columns + // in the subquery. + if (columnNames.size != outputAttrs.size) { + u.failAnalysis("Number of column aliases does not match number of columns. " + + s"Number of column aliases: ${columnNames.size}; " + + s"number of columns: ${outputAttrs.size}.") + } + val aliases = outputAttrs.zip(columnNames).map { case (attr, aliasName) => + Alias(attr, aliasName)() + } + Project(aliases, child) + } + } + /** * Turns projections that contain aggregate expressions into aggregations. */ @@ -2234,7 +2259,9 @@ object EliminateUnions extends Rule[LogicalPlan] { /** * Cleans up unnecessary Aliases inside the plan. 
Basically we only need Alias as a top level * expression in Project(project list) or Aggregate(aggregate expressions) or - * Window(window expressions). + * Window(window expressions). Notice that if an expression has other expression parameters which + * are not in its `children`, e.g. `RuntimeReplaceable`, the transformation for Aliases in this + * rule can't work for those parameters. */ object CleanupAliases extends Rule[LogicalPlan] { private def trimAliases(e: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 85c52792ef659..e235689cc36ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -108,11 +108,9 @@ trait CheckAnalysis extends PredicateHelper { case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") - case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, - SpecifiedWindowFrame(frame, - FrameBoundary(l), - FrameBoundary(h)))) - if order.isEmpty || frame != RowFrame || l != h => + case w @ WindowExpression(_: OffsetWindowFunction, + WindowSpecDefinition(_, order, frame: SpecifiedWindowFrame)) + if order.isEmpty || !frame.isOffset => failAnalysis("An offset window function can only be evaluated in an ordered " + s"row-based window frame with a single offset: $w") @@ -121,15 +119,10 @@ trait CheckAnalysis extends PredicateHelper { // function. e match { case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + w case _ => failAnalysis(s"Expression '$e' not supported within a window function.") } - // Make sure the window specification is valid. 
- s.validate match { - case Some(m) => - failAnalysis(s"Window specification $s is not valid because $m") - case None => w - } case s: SubqueryExpression => checkSubqueryExpression(operator, s) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a78e1c98e89de..25af014f67fe9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,6 +59,7 @@ object TypeCoercion { PropagateTypes :: ImplicitTypeCasts :: DateTimeOperations :: + WindowFrameCoercion :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -805,4 +806,26 @@ object TypeCoercion { Option(ret) } } + + /** + * Cast WindowFrame boundaries to the type they operate upon. + */ + object WindowFrameCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) + if order.resolved => + s.copy(frameSpecification = SpecifiedWindowFrame( + RangeFrame, + createBoundaryCast(lower, order.dataType), + createBoundaryCast(upper, order.dataType))) + } + + private def createBoundaryCast(boundary: Expression, dt: DataType): Expression = { + boundary match { + case e: SpecialFrameBoundary => e + case e: Expression if e.dataType != dt && Cast.canCast(e.dataType, dt) => Cast(e, dt) + case _ => boundary + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index fb322697c7c68..b7a704dc8453a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.types.{DataType, Metadata, StructType} @@ -422,6 +422,27 @@ case class UnresolvedAlias( override lazy val resolved = false } +/** + * Aliased column names resolved by positions for subquery. We could add alias names for output + * columns in the subquery: + * {{{ + * // Assign alias names for output columns + * SELECT col1, col2 FROM testData AS t(col1, col2); + * }}} + * + * @param outputColumnNames the alias names to be assigned to the output columns of the subquery. + * @param child the logical plan of this subquery. + */ +case class UnresolvedSubqueryColumnAliases( + outputColumnNames: Seq[String], + child: LogicalPlan) + extends UnaryNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the deserializer expression and the attributes that are available during the resolution * for it. 
Deserializer expression is a special kind of expression that is not always resolved by diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b847ef7bfaa97..74c4cddf2b47e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -241,6 +241,10 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override def nullable: Boolean = child.nullable override def foldable: Boolean = child.foldable override def dataType: DataType = child.dataType + // As this expression gets replaced at optimization with its `child` expression, + // two `RuntimeReplaceable` are considered to be semantically equal if their `child` expressions + // are semantically equal. + override lazy val canonicalized: Expression = child.canonicalized } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 4c8b177237d23..1a48995358af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -74,6 +74,13 @@ package object expressions { def initialize(partitionIndex: Int): Unit = {} } + /** + * An identity projection. This returns the input row. + */ + object IdentityProjection extends Projection { + override def apply(row: InternalRow): InternalRow = row + } + /** * Converts a [[InternalRow]] to another Row given a sequence of expression that define each * column of the new row. 
If the schema of the input row is specified, then the given expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 88afd43223d1d..a829dccfd3e36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} @@ -43,34 +42,7 @@ case class WindowSpecDefinition( orderSpec: Seq[SortOrder], frameSpecification: WindowFrame) extends Expression with WindowSpec with Unevaluable { - def validate: Option[String] = frameSpecification match { - case UnspecifiedFrame => - Some("Found a UnspecifiedFrame. It should be converted to a SpecifiedWindowFrame " + - "during analysis. Please file a bug report.") - case frame: SpecifiedWindowFrame => frame.validate.orElse { - def checkValueBasedBoundaryForRangeFrame(): Option[String] = { - if (orderSpec.length > 1) { - // It is not allowed to have a value-based PRECEDING and FOLLOWING - // as the boundary of a Range Window Frame. 
- Some("This Range Window Frame only accepts at most one ORDER BY expression.") - } else if (orderSpec.nonEmpty && !orderSpec.head.dataType.isInstanceOf[NumericType]) { - Some("The data type of the expression in the ORDER BY clause should be a numeric type.") - } else { - None - } - } - - (frame.frameType, frame.frameStart, frame.frameEnd) match { - case (RangeFrame, vp: ValuePreceding, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, vf: ValueFollowing, _) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vp: ValuePreceding) => checkValueBasedBoundaryForRangeFrame() - case (RangeFrame, _, vf: ValueFollowing) => checkValueBasedBoundaryForRangeFrame() - case (_, _, _) => None - } - } - } - - override def children: Seq[Expression] = partitionSpec ++ orderSpec + override def children: Seq[Expression] = partitionSpec ++ orderSpec :+ frameSpecification override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && @@ -78,23 +50,46 @@ case class WindowSpecDefinition( override def nullable: Boolean = true override def foldable: Boolean = false - override def dataType: DataType = throw new UnsupportedOperationException + override def dataType: DataType = throw new UnsupportedOperationException("dataType") - override def sql: String = { - val partition = if (partitionSpec.isEmpty) { - "" - } else { - "PARTITION BY " + partitionSpec.map(_.sql).mkString(", ") + " " + override def checkInputDataTypes(): TypeCheckResult = { + frameSpecification match { + case UnspecifiedFrame => + TypeCheckFailure( + "Cannot use an UnspecifiedFrame. This should have been converted during analysis. 
" + + "Please file a bug report.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && !f.isUnbounded && + orderSpec.isEmpty => + TypeCheckFailure( + "A range window frame cannot be used in an unordered window specification.") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + orderSpec.size > 1 => + TypeCheckFailure( + s"A range window frame with value boundaries cannot be used in a window specification " + + s"with multiple order by expressions: ${orderSpec.mkString(",")}") + case f: SpecifiedWindowFrame if f.frameType == RangeFrame && f.isValueBound && + !isValidFrameType(f.valueBoundary.head.dataType) => + TypeCheckFailure( + s"The data type '${orderSpec.head.dataType}' used in the order specification does " + + s"not match the data type '${f.valueBoundary.head.dataType}' which is used in the " + + "range frame.") + case _ => TypeCheckSuccess } + } - val order = if (orderSpec.isEmpty) { - "" - } else { - "ORDER BY " + orderSpec.map(_.sql).mkString(", ") + " " + override def sql: String = { + def toSql(exprs: Seq[Expression], prefix: String): Seq[String] = { + Seq(exprs).filter(_.nonEmpty).map(_.map(_.sql).mkString(prefix, ", ", "")) } - s"($partition$order${frameSpecification.toString})" + val elements = + toSql(partitionSpec, "PARTITION BY ") ++ + toSql(orderSpec, "ORDER BY ") ++ + Seq(frameSpecification.sql) + elements.mkString("(", " ", ")") } + + private def isValidFrameType(ft: DataType): Boolean = orderSpec.head.dataType == ft } /** @@ -106,22 +101,26 @@ case class WindowSpecReference(name: String) extends WindowSpec /** * The trait used to represent the type of a Window Frame. */ -sealed trait FrameType +sealed trait FrameType { + def inputType: AbstractDataType + def sql: String +} /** - * RowFrame treats rows in a partition individually. When a [[ValuePreceding]] - * or a [[ValueFollowing]] is used as its [[FrameBoundary]], the value is considered - * as a physical offset. 
+ * RowFrame treats rows in a partition individually. Values used in a row frame are considered + * to be physical offsets. * For example, `ROW BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a 3-row frame, * from the row that precedes the current row to the row that follows the current row. */ -case object RowFrame extends FrameType +case object RowFrame extends FrameType { + override def inputType: AbstractDataType = IntegerType + override def sql: String = "ROWS" +} /** - * RangeFrame treats rows in a partition as groups of peers. - * All rows having the same `ORDER BY` ordering are considered as peers. - * When a [[ValuePreceding]] or a [[ValueFollowing]] is used as its [[FrameBoundary]], - * the value is considered as a logical offset. + * RangeFrame treats rows in a partition as groups of peers. All rows having the same `ORDER BY` + * ordering are considered as peers. Values used in a range frame are considered to be logical + * offsets. * For example, assuming the value of the current row's `ORDER BY` expression `expr` is `v`, * `RANGE BETWEEN 1 PRECEDING AND 1 FOLLOWING` represents a frame containing rows whose values * `expr` are in the range of [v-1, v+1]. @@ -129,138 +128,144 @@ case object RowFrame extends FrameType * If `ORDER BY` clause is not defined, all rows in the partition are considered as peers * of the current row. */ -case object RangeFrame extends FrameType - -/** - * The trait used to represent the type of a Window Frame Boundary. - */ -sealed trait FrameBoundary { - def notFollows(other: FrameBoundary): Boolean +case object RangeFrame extends FrameType { + override def inputType: AbstractDataType = NumericType + override def sql: String = "RANGE" } /** - * Extractor for making working with frame boundaries easier. + * The trait used to represent special boundaries used in a window frame. 
*/ -object FrameBoundary { - def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary) - def unapply(boundary: FrameBoundary): Option[Int] = boundary match { - case CurrentRow => Some(0) - case ValuePreceding(offset) => Some(-offset) - case ValueFollowing(offset) => Some(offset) - case _ => None - } +sealed trait SpecialFrameBoundary extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = NullType + override def foldable: Boolean = false + override def nullable: Boolean = false } -/** UNBOUNDED PRECEDING boundary. */ -case object UnboundedPreceding extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => true - case vp: ValuePreceding => true - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED PRECEDING" +/** UNBOUNDED boundary. */ +case object UnboundedPreceding extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED PRECEDING" } -/** PRECEDING boundary. */ -case class ValuePreceding(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case ValuePreceding(anotherValue) => value >= anotherValue - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = s"$value PRECEDING" +case object UnboundedFollowing extends SpecialFrameBoundary { + override def sql: String = "UNBOUNDED FOLLOWING" } /** CURRENT ROW boundary. */ -case object CurrentRow extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => true - case vf: ValueFollowing => true - case UnboundedFollowing => true - } - - override def toString: String = "CURRENT ROW" -} - -/** FOLLOWING boundary. 
*/ -case class ValueFollowing(value: Int) extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case ValueFollowing(anotherValue) => value <= anotherValue - case UnboundedFollowing => true - } - - override def toString: String = s"$value FOLLOWING" -} - -/** UNBOUNDED FOLLOWING boundary. */ -case object UnboundedFollowing extends FrameBoundary { - def notFollows(other: FrameBoundary): Boolean = other match { - case UnboundedPreceding => false - case vp: ValuePreceding => false - case CurrentRow => false - case vf: ValueFollowing => false - case UnboundedFollowing => true - } - - override def toString: String = "UNBOUNDED FOLLOWING" +case object CurrentRow extends SpecialFrameBoundary { + override def sql: String = "CURRENT ROW" } /** * Represents a window frame. */ -sealed trait WindowFrame +sealed trait WindowFrame extends Expression with Unevaluable { + override def children: Seq[Expression] = Nil + override def dataType: DataType = throw new UnsupportedOperationException("dataType") + override def foldable: Boolean = false + override def nullable: Boolean = false +} /** Used as a placeholder when a frame specification is not defined. */ case object UnspecifiedFrame extends WindowFrame -/** A specified Window Frame. */ +/** + * A specified Window Frame. The val lower/upper can be either a foldable [[Expression]] or a + * [[SpecialFrameBoundary]]. + */ case class SpecifiedWindowFrame( frameType: FrameType, - frameStart: FrameBoundary, - frameEnd: FrameBoundary) extends WindowFrame { - - /** If this WindowFrame is valid or not. 
*/ - def validate: Option[String] = (frameType, frameStart, frameEnd) match { - case (_, UnboundedFollowing, _) => - Some(s"$UnboundedFollowing is not allowed as the start of a Window Frame.") - case (_, _, UnboundedPreceding) => - Some(s"$UnboundedPreceding is not allowed as the end of a Window Frame.") - // case (RowFrame, start, end) => ??? RowFrame specific rule - // case (RangeFrame, start, end) => ??? RangeFrame specific rule - case (_, start, end) => - if (start.notFollows(end)) { - None - } else { - val reason = - s"The end of this Window Frame $end is smaller than the start of " + - s"this Window Frame $start." - Some(reason) - } + lower: Expression, + upper: Expression) + extends WindowFrame { + + override def children: Seq[Expression] = lower :: upper :: Nil + + lazy val valueBoundary: Seq[Expression] = + children.filterNot(_.isInstanceOf[SpecialFrameBoundary]) + + override def checkInputDataTypes(): TypeCheckResult = { + // Check lower value. + val lowerCheck = checkBoundary(lower, "lower") + if (lowerCheck.isFailure) { + return lowerCheck + } + + // Check upper value. + val upperCheck = checkBoundary(upper, "upper") + if (upperCheck.isFailure) { + return upperCheck + } + + // Check combination (of expressions). 
+ (lower, upper) match { + case (l: Expression, u: Expression) if !isValidFrameBoundary(l, u) => + TypeCheckFailure(s"Window frame upper bound '$upper' does not follow the lower bound " + + s"'$lower'.") + case (l: SpecialFrameBoundary, _) => TypeCheckSuccess + case (_, u: SpecialFrameBoundary) => TypeCheckSuccess + case (l: Expression, u: Expression) if l.dataType != u.dataType => + TypeCheckFailure( + s"Window frame bounds '$lower' and '$upper' do not have the same data type: " + + s"'${l.dataType.catalogString}' <> '${u.dataType.catalogString}'") + case (l: Expression, u: Expression) if isGreaterThan(l, u) => + TypeCheckFailure( + "The lower bound of a window frame must be less than or equal to the upper bound") + case _ => TypeCheckSuccess + } + } + + override def sql: String = { + val lowerSql = boundarySql(lower) + val upperSql = boundarySql(upper) + s"${frameType.sql} BETWEEN $lowerSql AND $upperSql" } - override def toString: String = frameType match { - case RowFrame => s"ROWS BETWEEN $frameStart AND $frameEnd" - case RangeFrame => s"RANGE BETWEEN $frameStart AND $frameEnd" + def isUnbounded: Boolean = lower == UnboundedPreceding && upper == UnboundedFollowing + + def isValueBound: Boolean = valueBoundary.nonEmpty + + def isOffset: Boolean = (lower, upper) match { + case (l: Expression, u: Expression) => frameType == RowFrame && l == u + case _ => false + } + + private def boundarySql(expr: Expression): String = expr match { + case e: SpecialFrameBoundary => e.sql + case UnaryMinus(n) => n.sql + " PRECEDING" + case e: Expression => e.sql + " FOLLOWING" + } + + private def isGreaterThan(l: Expression, r: Expression): Boolean = { + GreaterThan(l, r).eval().asInstanceOf[Boolean] + } + + private def checkBoundary(b: Expression, location: String): TypeCheckResult = b match { + case _: SpecialFrameBoundary => TypeCheckSuccess + case e: Expression if !e.foldable => + TypeCheckFailure(s"Window frame $location bound '$e' is not a literal.") + case e: 
Expression if !frameType.inputType.acceptsType(e.dataType) => + TypeCheckFailure( + s"The data type of the $location bound '${e.dataType}' does not match " + + s"the expected data type '${frameType.inputType}'.") + case _ => TypeCheckSuccess + } + + private def isValidFrameBoundary(l: Expression, u: Expression): Boolean = { + (l, u) match { + case (UnboundedFollowing, _) => false + case (_, UnboundedPreceding) => false + case _ => true + } + } } object SpecifiedWindowFrame { /** - * * @param hasOrderSpecification If the window spec has order by expressions. * @param acceptWindowFrame If the window function accepts user-specified frame. - * @return + * @return the default window frame. */ def defaultWindowFrame( hasOrderSpecification: Boolean, @@ -351,20 +356,25 @@ abstract class OffsetWindowFunction override def nullable: Boolean = default == null || default.nullable || input.nullable - override lazy val frame = { - // This will be triggered by the Analyzer. - val offsetValue = offset.eval() match { - case o: Int => o - case x => throw new AnalysisException( - s"Offset expression must be a foldable integer expression: $x") - } + override lazy val frame: WindowFrame = { val boundary = direction match { - case Ascending => ValueFollowing(offsetValue) - case Descending => ValuePreceding(offsetValue) + case Ascending => offset + case Descending => UnaryMinus(offset) } SpecifiedWindowFrame(RowFrame, boundary, boundary) } + override def checkInputDataTypes(): TypeCheckResult = { + val check = super.checkInputDataTypes() + if (check.isFailure) { + check + } else if (!offset.foldable) { + TypeCheckFailure(s"Offset expression '$offset' must be a literal.") + } else { + TypeCheckSuccess + } + } + override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 
45c1d3d430e0d..07578261781b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -750,20 +750,28 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} */ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - val alias = if (ctx.strictIdentifier == null) { + val alias = if (ctx.tableAlias.strictIdentifier == null) { // For un-aliased subqueries, use a default alias name that is not likely to conflict with // normal subquery names, so that parent operators can only access the columns in subquery by // unqualified names. Users can still use this special qualifier to access columns if they // know it, but that's not recommended. "__auto_generated_subquery_name" } else { - ctx.strictIdentifier.getText + ctx.tableAlias.strictIdentifier.getText + } + val subquery = SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) + if (ctx.tableAlias.identifierList != null) { + val columnAliases = visitIdentifierList(ctx.tableAlias.identifierList) + UnresolvedSubqueryColumnAliases(columnAliases, subquery) + } else { + subquery } - - SubqueryAlias(alias, plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample)) } /** @@ -1179,32 +1187,26 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value - * Preceding/Following boundaries. These expressions must be constant (foldable) and return an - * integer value. 
+ * Create or resolve a frame boundary expressions. */ - override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { - // We currently only allow foldable integers. - def value: Int = { + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { val e = expression(ctx.expression) - validate(e.resolved && e.foldable && e.dataType == IntegerType, - "Frame bound value must be a constant integer.", - ctx) - e.eval().asInstanceOf[Int] + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e } - // Create the FrameBoundary ctx.boundType.getType match { case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => UnboundedPreceding case SqlBaseParser.PRECEDING => - ValuePreceding(value) + UnaryMinus(value) case SqlBaseParser.CURRENT => CurrentRow case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => UnboundedFollowing case SqlBaseParser.FOLLOWING => - ValueFollowing(value) + value } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 7375a0bcbae75..b6889f21cc6ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -688,8 +688,6 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case id: FunctionIdentifier => true case spec: BucketSpec => true case catalog: CatalogTable => true - case boundary: FrameBoundary => true - case frame: WindowFrame => true case partition: Partitioning => true case resource: FunctionResource => true case broadcast: BroadcastMode => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 54006e20a3eb6..b314ef4e35d6d 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer +import scala.util.control.NonFatal import scala.util.Try import org.json4s.JsonDSL._ @@ -467,10 +468,16 @@ object StructType extends AbstractDataType { leftFields.foreach { case leftField @ StructField(leftName, leftType, leftNullable, _) => rightMapped.get(leftName) - .map { case rightField @ StructField(_, rightType, rightNullable, _) => - leftField.copy( - dataType = merge(leftType, rightType), - nullable = leftNullable || rightNullable) + .map { case rightField @ StructField(rightName, rightType, rightNullable, _) => + try { + leftField.copy( + dataType = merge(leftType, rightType), + nullable = leftNullable || rightNullable) + } catch { + case NonFatal(e) => + throw new SparkException(s"Failed to merge fields '$leftName' and " + + s"'$rightName'. 
" + e.getMessage) + } } .orElse { Some(leftField) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 7311dc3899e53..4e0613619add6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -190,7 +190,7 @@ class AnalysisErrorSuite extends AnalysisTest { WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, - SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + SpecifiedWindowFrame(RangeFrame, Literal(1), Literal(2)))).as('window)), "window frame" :: "must match the required frame" :: Nil) errorTest( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index be26b1b26f175..9bcf4773fa903 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -470,4 +470,24 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { Seq("Number of column aliases does not match number of columns. 
Table name: TaBlE3; " + "number of column aliases: 5; number of columns: 4.")) } + + test("SPARK-20962 Support subquery column aliases in FROM clause") { + def tableColumnsWithAliases(outputNames: Seq[String]): LogicalPlan = { + UnresolvedSubqueryColumnAliases( + outputNames, + SubqueryAlias( + "t", + UnresolvedRelation(TableIdentifier("TaBlE3"))) + ).select(star()) + } + assertAnalysisSuccess(tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: Nil)) + assertAnalysisError( + tableColumnsWithAliases("col1" :: Nil), + Seq("Number of column aliases does not match number of columns. " + + "Number of column aliases: 1; number of columns: 4.")) + assertAnalysisError( + tableColumnsWithAliases("col1" :: "col2" :: "col3" :: "col4" :: "col5" :: Nil), + Seq("Number of column aliases does not match number of columns. " + + "Number of column aliases: 5; number of columns: 4.")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index b3994ab0828ad..d62e3b6dfe34f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1109,6 +1109,42 @@ class TypeCoercionSuite extends AnalysisTest { EqualTo(Literal(Array(1, 2)), Literal("123")), EqualTo(Literal(Array(1, 2)), Literal("123"))) } + + test("cast WindowFrame boundaries to the type they operate upon") { + // Can cast frame boundaries to order dataType. 
+ ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(3), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, Cast(3, LongType), Literal(2147483648L))) + ) + // Cannot cast frame boundaries to order dataType. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal.default(DateType), Ascending)), + SpecifiedWindowFrame(RangeFrame, Literal(10.0), Literal(2147483648L))) + ) + // Should not cast SpecialFrameBoundary. + ruleTest(WindowFrameCoercion, + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)), + windowSpec( + Seq(UnresolvedAttribute("a")), + Seq(SortOrder(Literal(1L), Ascending)), + SpecifiedWindowFrame(RangeFrame, CurrentRow, UnboundedFollowing)) + ) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 45f9f72dccc45..76c79b3d0760c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -267,16 +267,17 @@ class ExpressionParserSuite extends PlanTest { // Range/Row val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("10 preceding", -Literal(10), CurrentRow), + 
("2147483648 preceding", -Literal(2147483648L), CurrentRow), + ("3 + 1 following", Add(Literal(3), Literal(1)), CurrentRow), ("unbounded preceding", UnboundedPreceding, CurrentRow), ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), ("between unbounded preceding and unbounded following", UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ("between 10 preceding and current row", -Literal(10), CurrentRow), + ("between current row and 5 following", CurrentRow, Literal(5)), + ("between 10 preceding and 5 following", -Literal(10), Literal(5)) ) frameTypes.foreach { case (frameTypeSql, frameType) => @@ -288,13 +289,9 @@ class ExpressionParserSuite extends PlanTest { } } - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - // We cannot use an arbitrary expression. 
intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") + "Frame bound value must be a literal.") } test("row constructor") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 6dad097041a15..c7f39ae18162e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -243,7 +243,7 @@ class PlanParserSuite extends AnalysisTest { val sql = "select * from t" val plan = table("t").select(star()) val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + SpecifiedWindowFrame(RowFrame, -Literal(1), Literal(1))) // Test window resolution. 
val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) @@ -495,6 +495,17 @@ class PlanParserSuite extends AnalysisTest { .select(star())) } + test("SPARK-20962 Support subquery column aliases in FROM clause") { + assertEqual( + "SELECT * FROM (SELECT a AS x, b AS y FROM t) t(col1, col2)", + UnresolvedSubqueryColumnAliases( + Seq("col1", "col2"), + SubqueryAlias( + "t", + UnresolvedRelation(TableIdentifier("t")).select('a.as("x"), 'b.as("y"))) + ).select(star())) + } + test("inline table") { assertEqual("values 1, 2, 3, 4", UnresolvedInlineTable(Seq("col1"), Seq(1, 2, 3, 4).map(x => Seq(Literal(x))))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 819078218c546..4fc947a88f6b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -436,21 +436,22 @@ class TreeNodeSuite extends SparkFunSuite { "bucketColumnNames" -> "[bucket]", "sortColumnNames" -> "[sort]")) - // Converts FrameBoundary to JSON - assertJSON( - ValueFollowing(3), - JObject( - "product-class" -> classOf[ValueFollowing].getName, - "value" -> 3)) - // Converts WindowFrame to JSON assertJSON( - SpecifiedWindowFrame(RowFrame, UnboundedFollowing, CurrentRow), - JObject( - "product-class" -> classOf[SpecifiedWindowFrame].getName, - "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), - "frameStart" -> JObject("object" -> JString(UnboundedFollowing.getClass.getName)), - "frameEnd" -> JObject("object" -> JString(CurrentRow.getClass.getName)))) + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow), + List( + JObject( + "class" -> classOf[SpecifiedWindowFrame].getName, + "num-children" -> 2, + "frameType" -> JObject("object" -> JString(RowFrame.getClass.getName)), + "lower" -> 0, + "upper" -> 1), + JObject( + 
"class" -> UnboundedPreceding.getClass.getName, + "num-children" -> 0), + JObject( + "class" -> CurrentRow.getClass.getName, + "num-children" -> 0))) // Converts Partitioning to JSON assertJSON( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index c4635c8f126af..193826d66be26 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -142,9 +142,11 @@ class DataTypeSuite extends SparkFunSuite { val right = StructType( StructField("b", LongType) :: Nil) - intercept[SparkException] { + val message = intercept[SparkException] { left.merge(right) - } + }.getMessage + assert(message.equals("Failed to merge fields 'b' and 'b'. " + + "Failed to merge incompatible data types FloatType and LongType")) } test("existsRecursively") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 255c4064eb574..0fcda46c9b3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -499,7 +499,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
    *
  • `compression` (default is the value specified in `spark.sql.parquet.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive - * shorten names(none, `snappy`, `gzip`, and `lzo`). This will override + * shorten names(`none`, `snappy`, `gzip`, and `lzo`). This will override * `spark.sql.parquet.compression.codec`.
  • *
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index 1820cb0ef540b..0766e37826cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.types.IntegerType /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -109,46 +108,50 @@ case class WindowExec( * * This method uses Code Generation. It can only be used on the executor side. * - * @param frameType to evaluate. This can either be Row or Range based. - * @param offset with respect to the row. + * @param frame to evaluate. This can either be a Row or Range frame. + * @param bound with respect to the row. * @return a bound ordering object. */ - private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = { - frameType match { - case RangeFrame => - val (exprs, current, bound) = if (offset == 0) { - // Use the entire order expression when the offset is 0. - val exprs = orderSpec.map(_.child) - val buildProjection = () => newMutableProjection(exprs, child.output) - (orderSpec, buildProjection(), buildProjection()) - } else if (orderSpec.size == 1) { - // Use only the first order expression when the offset is non-null. - val sortExpr = orderSpec.head - val expr = sortExpr.child - // Create the projection which returns the current 'value'. 
- val current = newMutableProjection(expr :: Nil, child.output) - // Flip the sign of the offset when processing the order is descending - val boundOffset = sortExpr.direction match { - case Descending => -offset - case Ascending => offset - } - // Create the projection which returns the current 'value' modified by adding the offset. - val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output) - (sortExpr :: Nil, current, bound) - } else { - sys.error("Non-Zero range offsets are not supported for windows " + - "with multiple order expressions.") + private[this] def createBoundOrdering(frame: FrameType, bound: Expression): BoundOrdering = { + (frame, bound) match { + case (RowFrame, CurrentRow) => + RowBoundOrdering(0) + + case (RowFrame, IntegerLiteral(offset)) => + RowBoundOrdering(offset) + + case (RangeFrame, CurrentRow) => + val ordering = newOrdering(orderSpec, child.output) + RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection) + + case (RangeFrame, offset: Expression) if orderSpec.size == 1 => + // Use only the first order expression when the offset is non-null. + val sortExpr = orderSpec.head + val expr = sortExpr.child + + // Create the projection which returns the current 'value'. + val current = newMutableProjection(expr :: Nil, child.output) + + // Flip the sign of the offset when processing the order is descending + val boundOffset = sortExpr.direction match { + case Descending => UnaryMinus(offset) + case Ascending => offset } + + // Create the projection which returns the current 'value' modified by adding the offset. + val boundExpr = Add(expr, Cast(boundOffset, expr.dataType)) + val bound = newMutableProjection(boundExpr :: Nil, child.output) + // Construct the ordering. This is used to compare the result of current value projection // to the result of bound value projection. 
This is done manually because we want to use // Code Generation (if it is enabled). - val sortExprs = exprs.zipWithIndex.map { case (e, i) => - SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction) - } - val ordering = newOrdering(sortExprs, Nil) + val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil + val ordering = newOrdering(boundSortExprs, Nil) RangeBoundOrdering(ordering, current, bound) - case RowFrame => RowBoundOrdering(offset) + + case (RangeFrame, _) => + sys.error("Non-Zero range offsets are not supported for windows " + + "with multiple order expressions.") } } @@ -157,13 +160,13 @@ case class WindowExec( * [[WindowExpression]]s and factory function for the WindowFrameFunction. */ private[this] lazy val windowFrameExpressionFactoryPairs = { - type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type FrameKey = (String, FrameType, Expression, Expression) type ExpressionBuffer = mutable.Buffer[Expression] val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] // Add a function and its function to the map for a given frame. def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { - val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val key = (tpe, fr.frameType, fr.lower, fr.upper) val (es, fns) = framedFunctions.getOrElseUpdate( key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) es += e @@ -203,7 +206,7 @@ case class WindowExec( // Create the factory val factory = key match { // Offset Frame - case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + case ("OFFSET", _, IntegerLiteral(offset), _) => target: InternalRow => new OffsetWindowFunctionFrame( target, @@ -215,38 +218,38 @@ case class WindowExec( newMutableProjection(expressions, schema, subexpressionEliminationEnabled), offset) + // Entire Partition Frame. 
+ case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) => + target: InternalRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + // Growing Frame. - case ("AGGREGATE", frameType, None, Some(high)) => + case ("AGGREGATE", frameType, UnboundedPreceding, upper) => target: InternalRow => { new UnboundedPrecedingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, high)) + createBoundOrdering(frameType, upper)) } // Shrinking Frame. - case ("AGGREGATE", frameType, Some(low), None) => + case ("AGGREGATE", frameType, lower, UnboundedFollowing) => target: InternalRow => { new UnboundedFollowingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low)) + createBoundOrdering(frameType, lower)) } // Moving Frame. - case ("AGGREGATE", frameType, Some(low), Some(high)) => + case ("AGGREGATE", frameType, lower, upper) => target: InternalRow => { new SlidingWindowFunctionFrame( target, processor, - createBoundOrdering(frameType, low), - createBoundOrdering(frameType, high)) - } - - // Entire Partition Frame. 
- case ("AGGREGATE", frameType, None, None) => - target: InternalRow => { - new UnboundedWindowFunctionFrame(target, processor) + createBoundOrdering(frameType, lower), + createBoundOrdering(frameType, upper)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index f653890f6c7ba..f8b404de77a4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.sql.Column +import org.apache.spark.sql.{AnalysisException, Column} import org.apache.spark.sql.catalyst.expressions._ /** @@ -123,7 +123,24 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rowsBetween. def rowsBetween(start: Long, end: Long): WindowSpec = { - between(RowFrame, start, end) + val boundaryStart = start match { + case 0 => CurrentRow + case Long.MinValue => UnboundedPreceding + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary start is not a valid integer: $x") + } + + val boundaryEnd = end match { + case 0 => CurrentRow + case Long.MaxValue => UnboundedFollowing + case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt) + case x => throw new AnalysisException(s"Boundary end is not a valid integer: $x") + } + + new WindowSpec( + partitionSpec, + orderSpec, + SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd)) } /** @@ -174,28 +191,22 @@ class WindowSpec private[sql]( */ // Note: when updating the doc for this method, also update Window.rangeBetween. 
def rangeBetween(start: Long, end: Long): WindowSpec = { - between(RangeFrame, start, end) - } - - private def between(typ: FrameType, start: Long, end: Long): WindowSpec = { val boundaryStart = start match { case 0 => CurrentRow case Long.MinValue => UnboundedPreceding - case x if x < 0 => ValuePreceding(-start.toInt) - case x if x > 0 => ValueFollowing(start.toInt) + case x => Literal(x) } val boundaryEnd = end match { case 0 => CurrentRow case Long.MaxValue => UnboundedFollowing - case x if x < 0 => ValuePreceding(-end.toInt) - case x if x > 0 => ValueFollowing(end.toInt) + case x => Literal(x) } new WindowSpec( partitionSpec, orderSpec, - SpecifiedWindowFrame(typ, boundaryStart, boundaryEnd)) + SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql index 2b5b692d29ef4..f1461032065ad 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql @@ -23,3 +23,7 @@ SELECT float(1), double(1), decimal(1); SELECT date("2014-04-04"), timestamp(date("2014-04-04")); -- error handling: only one argument SELECT string(1, 2); + +-- SPARK-21555: RuntimeReplaceable used in group by +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st); +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value"); diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql index c90a9c7f85587..85481cbbf9377 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-aliases.sql @@ -15,3 +15,6 @@ SELECT * FROM testData AS t(col1); -- Check alias duplication SELECT a AS 
col1, b AS col2 FROM testData AS t(c, d); + +-- Subquery aliases in FROM clause +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2); diff --git a/sql/core/src/test/resources/sql-tests/inputs/window.sql b/sql/core/src/test/resources/sql-tests/inputs/window.sql index c800fc3d49891..342e5719e9a60 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/window.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/window.sql @@ -1,24 +1,44 @@ -- Test data. CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate); +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate); -- RowsBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val ROWS CURRENT ROW) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val ROWS BETWEEN UNBOUNDED PRECEDING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val; SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long; -- RangeBetween with reverse OrderBy SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +-- Invalid window frame +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS 
BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val; +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val; + + -- Window functions SELECT val, cate, max(val) OVER w AS max, diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 732b11050f461..e035505f15d28 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 13 +-- Number of queries: 15 -- !query 0 @@ -122,3 +122,19 @@ struct<> -- !query 12 output org.apache.spark.sql.AnalysisException Function string accepts only one argument; line 1 pos 7 + + +-- !query 13 +CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st) +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value") +-- !query 14 schema +struct +-- !query 14 output +gamma 1 diff --git a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out 
b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out index 7abbcd834a523..4459f3186c77b 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-aliases.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 7 +-- Number of queries: 8 -- !query 0 @@ -61,3 +61,11 @@ struct<> -- !query 6 output org.apache.spark.sql.AnalysisException cannot resolve '`a`' given input columns: [t.c, t.d]; line 1 pos 7 + + +-- !query 7 +SELECT * FROM (SELECT 1 AS a, 1 AS b) t(col1, col2) +-- !query 7 schema +struct +-- !query 7 output +1 1 diff --git a/sql/core/src/test/resources/sql-tests/results/window.sql.out b/sql/core/src/test/resources/sql-tests/results/window.sql.out index aa5856138ed81..97511068b323c 100644 --- a/sql/core/src/test/resources/sql-tests/results/window.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/window.sql.out @@ -1,11 +1,12 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 11 +-- Number of queries: 19 -- !query 0 CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES -(null, "a"), (1, "a"), (1, "a"), (2, "a"), (1, "b"), (2, "b"), (3, "b"), (null, null), (3, null) -AS testData(val, cate) +(null, 1L, "a"), (1, 1L, "a"), (1, 2L, "a"), (2, 2147483650L, "a"), (1, null, "b"), (2, 3L, "b"), +(3, 2147483650L, "b"), (null, null, null), (3, 1L, null) +AS testData(val, val_long, cate) -- !query 0 schema struct<> -- !query 0 output @@ -47,11 +48,21 @@ NULL a 1 -- !query 3 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +ROWS BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN CURRENT ROW AND 2147483648L FOLLOWING' due to data type mismatch: The data type of the upper bound 'LongType does not match the expected 
data type 'IntegerType'.; line 1 pos 41 + + +-- !query 4 SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val RANGE 1 PRECEDING) FROM testData ORDER BY cate, val --- !query 3 schema +-- !query 4 schema struct --- !query 3 output +-- !query 4 output NULL NULL 0 3 NULL 1 NULL a 0 @@ -63,12 +74,12 @@ NULL a 0 3 b 2 --- !query 4 +-- !query 5 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 4 schema +-- !query 5 schema struct --- !query 4 output +-- !query 5 output NULL NULL NULL 3 NULL 3 NULL a NULL @@ -80,12 +91,29 @@ NULL a NULL 3 b 3 --- !query 5 +-- !query 6 +SELECT val_long, cate, sum(val_long) OVER(PARTITION BY cate ORDER BY val_long +RANGE BETWEEN CURRENT ROW AND 2147483648 FOLLOWING) FROM testData ORDER BY cate, val_long +-- !query 6 schema +struct +-- !query 6 output +NULL NULL NULL +1 NULL 1 +1 a 4 +1 a 4 +2 a 2147483652 +2147483650 a 2147483650 +NULL b NULL +3 b 2147483653 +2147483650 b 2147483650 + + +-- !query 7 SELECT val, cate, sum(val) OVER(PARTITION BY cate ORDER BY val DESC RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val --- !query 5 schema +-- !query 7 schema struct --- !query 5 output +-- !query 7 output NULL NULL NULL 3 NULL 3 NULL a NULL @@ -97,7 +125,73 @@ NULL a NULL 3 b 5 --- !query 6 +-- !query 8 +SELECT val, cate, count(val) OVER(PARTITION BY cate +ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +cannot resolve 'ROWS BETWEEN UNBOUNDED FOLLOWING AND 1 FOLLOWING' due to data type mismatch: Window frame upper bound '1' does not followes the lower bound 'unboundedfollowing$()'.; line 1 pos 33 + + +-- !query 9 +SELECT val, cate, count(val) OVER(PARTITION BY cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 9 schema +struct<> +-- !query 9 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame cannot be used in an unordered window specification.; line 1 pos 33 + + +-- !query 10 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val, cate +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY testdata.`val` ASC NULLS FIRST, testdata.`cate` ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: A range window frame with value boundaries cannot be used in a window specification with multiple order by expressions: val#x ASC NULLS FIRST,cate#x ASC NULLS FIRST; line 1 pos 33 + + +-- !query 11 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY current_date +RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING) FROM testData ORDER BY cate, val +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '(PARTITION BY testdata.`cate` ORDER BY current_date() ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 1 FOLLOWING)' due to data type mismatch: The data type 'DateType' used in the order specification does not match the data type 'IntegerType' which is used in the range frame.; line 1 pos 33 + + +-- !query 12 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING) FROM testData ORDER BY cate, val +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve 'RANGE BETWEEN 1 FOLLOWING AND 1 PRECEDING' due to data type mismatch: The lower bound of a window frame must be less than or equal to the upper bound; line 1 pos 33 + + +-- !query 13 +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND 
current_date PRECEDING) FROM testData ORDER BY cate, val +-- !query 13 schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +Frame bound value must be a literal.(line 2, pos 30) + +== SQL == +SELECT val, cate, count(val) OVER(PARTITION BY cate ORDER BY val +RANGE BETWEEN CURRENT ROW AND current_date PRECEDING) FROM testData ORDER BY cate, val +------------------------------^^^ + + +-- !query 14 SELECT val, cate, max(val) OVER w AS max, min(val) OVER w AS min, @@ -124,9 +218,9 @@ approx_count_distinct(val) OVER w AS approx_count_distinct FROM testData WINDOW w AS (PARTITION BY cate ORDER BY val) ORDER BY cate, val --- !query 6 schema +-- !query 14 schema struct --- !query 6 output +-- !query 14 output NULL NULL NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.5 0.0 1 1 NULL NULL 0 3 NULL 3 3 3 1 3 3.0 NaN NULL 3 NULL 3 3 3 2 2 1.0 1.0 2 2 0.0 NaN 1 NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0.0 1 1 NULL NULL 0 @@ -138,11 +232,11 @@ NULL a NULL NULL NULL 0 NULL NULL NULL NULL NULL NULL NULL NULL NULL 1 1 0.25 0. 3 b 3 1 1 3 6 2.0 1.0 1 1 1 3 3 3 3 3 1.0 1.0 2 3 0.6666666666666666 1.0 3 --- !query 7 +-- !query 15 SELECT val, cate, avg(null) OVER(PARTITION BY cate ORDER BY val) FROM testData ORDER BY cate, val --- !query 7 schema +-- !query 15 schema struct --- !query 7 output +-- !query 15 output NULL NULL NULL 3 NULL NULL NULL a NULL @@ -154,20 +248,20 @@ NULL a NULL 3 b NULL --- !query 8 +-- !query 16 SELECT val, cate, row_number() OVER(PARTITION BY cate) FROM testData ORDER BY cate, val --- !query 8 schema +-- !query 16 schema struct<> --- !query 8 output +-- !query 16 output org.apache.spark.sql.AnalysisException Window function row_number() requires window to be ordered, please add ORDER BY clause. 
For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table; --- !query 9 +-- !query 17 SELECT val, cate, sum(val) OVER(), avg(val) OVER() FROM testData ORDER BY cate, val --- !query 9 schema +-- !query 17 schema struct --- !query 9 output +-- !query 17 output NULL NULL 13 1.8571428571428572 3 NULL 13 1.8571428571428572 NULL a 13 1.8571428571428572 @@ -179,7 +273,7 @@ NULL a 13 1.8571428571428572 3 b 13 1.8571428571428572 --- !query 10 +-- !query 18 SELECT val, cate, first_value(false) OVER w AS first_value, first_value(true, true) OVER w AS first_value_ignore_null, @@ -190,9 +284,9 @@ last_value(false, false) OVER w AS last_value_contain_null FROM testData WINDOW w AS () ORDER BY cate, val --- !query 10 schema +-- !query 18 schema struct --- !query 10 output +-- !query 18 output NULL NULL false true false false true false 3 NULL false true false false true false NULL a false true false false true false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 204858fa29787..9806e57f08744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -151,6 +151,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(2.0d), Row(2.0d))) } + test("row between should accept integer values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483647))), + Seq(Row(1, 3), Row(1, 4), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + + val e = intercept[AnalysisException]( 
+ df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rowsBetween(0, 2147483648L)))) + assert(e.message.contains("Boundary end is not a valid integer: 2147483648")) + } + + test("range between should accept integer/long values as boundary") { + val df = Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), + (3L, "2"), (2L, "1"), (2147483650L, "2")) + .toDF("key", "value") + df.createOrReplaceTempView("window_table") + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(0, 2147483648L))), + Seq(Row(1, 3), Row(1, 3), Row(2, 2), Row(3, 2), Row(2147483650L, 1), Row(2147483650L, 1)) + ) + checkAnswer( + df.select( + $"key", + count("key").over( + Window.partitionBy($"value").orderBy($"key").rangeBetween(-2147483649L, 0))), + Seq(Row(1, 2), Row(1, 2), Row(2, 3), Row(2147483650L, 2), Row(2147483650L, 4), Row(3, 1)) + ) + } + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.createOrReplaceTempView("window_table") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 50d8e3024598d..d194f58cd1cdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -127,9 +127,10 @@ class UnsafeFixedWidthAggregationMapSuite PAGE_SIZE_BYTES ) val groupKey = InternalRow(UTF8String.fromString("cats")) + val row = map.getAggregationBuffer(groupKey) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) - assert(map.getAggregationBuffer(groupKey) != null) + assert(row != null) val iter = map.iterator() assert(iter.next()) iter.getKey.getString(0) 
should be ("cats") @@ -138,7 +139,7 @@ class UnsafeFixedWidthAggregationMapSuite // Modifications to rows retrieved from the map should update the values in the map iter.getValue.setInt(0, 42) - map.getAggregationBuffer(groupKey).getInt(0) should be (42) + row.getInt(0) should be (42) map.free() } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index 17589cf44b998..f517bffccdf31 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -103,7 +103,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
SQL Statistics
++ +
SQL Statistics ({numStatement})
++
    {table.getOrElse("No statistics have been generated yet.")} @@ -164,7 +164,7 @@ private[ui] class ThriftServerPage(parent: ThriftServerTab) extends WebUIPage("" } val content = -
    Session Statistics
    ++ +
    Session Statistics ({numBatches})
    ++
      {table.getOrElse("No statistics have been generated yet.")} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala index 149ce1e195111..90f90599d5bf4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala @@ -98,27 +98,27 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL( WindowSpecDefinition('a.int :: Nil, Nil, frame), - s"(PARTITION BY `a` $frame)" + s"(PARTITION BY `a` ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), - s"(PARTITION BY `a`, `b` $frame)" + s"(PARTITION BY `a`, `b` ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST ${frame.sql})" ) checkSQL( WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), - s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST $frame)" + s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST ${frame.sql})" ) checkSQL( WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), - s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST $frame)" + s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST ${frame.sql})" ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 905b1c52afa69..b8a5a96faf15c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -164,6 +164,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( /** Clear the old 
time-to-files mappings along with old RDDs */ protected[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) batchTimeToSelectedFiles.synchronized { val oldFiles = batchTimeToSelectedFiles.filter(_._1 < (time - rememberDuration)) batchTimeToSelectedFiles --= oldFiles.keys