diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 4db9cc30fb0c1..306a9b8676539 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -46,15 +46,16 @@ setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj" #' @note NaiveBayesModel since 2.0.0 setClass("NaiveBayesModel", representation(jobj = "jobj")) -#' linear SVM Model +#' Linear SVM Model #' -#' Fits an linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in glmnet package +#' Fits a linear SVM model against a SparkDataFrame, similar to svm in e1071 package. +#' Currently only supports binary classification model with linear kernel. #' Users can print, make predictions on the produced model and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam The regularization parameter. +#' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. #' @param standardization Whether to standardize the training features before fitting the model. The coefficients @@ -111,10 +112,10 @@ setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formu new("LinearSVCModel", jobj = jobj) }) -# Predicted values based on an LinearSVCModel model +# Predicted values based on a LinearSVCModel model #' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on an LinearSVCModel. +#' @return \code{predict} returns the predicted values based on a LinearSVCModel. 
#' @rdname spark.svmLinear #' @aliases predict,LinearSVCModel,SparkDataFrame-method #' @export @@ -124,13 +125,12 @@ setMethod("predict", signature(object = "LinearSVCModel"), predict_internal(object, newData) }) -# Get the summary of an LinearSVCModel +# Get the summary of a LinearSVCModel -#' @param object an LinearSVCModel fitted by \code{spark.svmLinear}. +#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}. #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list includes \code{coefficients} (coefficients of the fitted model), -#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes), -#' \code{numFeatures} (number of features). +#' \code{numClasses} (number of classes), \code{numFeatures} (number of features). #' @rdname spark.svmLinear #' @aliases summary,LinearSVCModel-method #' @export @@ -138,22 +138,14 @@ setMethod("predict", signature(object = "LinearSVCModel"), setMethod("summary", signature(object = "LinearSVCModel"), function(object) { jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - coefficients <- callJMethod(jobj, "coefficients") - nCol <- length(coefficients) / length(features) - coefficients <- matrix(unlist(coefficients), ncol = nCol) - intercept <- callJMethod(jobj, "intercept") + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) numClasses <- callJMethod(jobj, "numClasses") numFeatures <- callJMethod(jobj, "numFeatures") - if (nCol == 1) { - colnames(coefficients) <- c("Estimate") - } else { - colnames(coefficients) <- unlist(labels) - } - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients, intercept = intercept, - numClasses = numClasses, numFeatures = numFeatures) + list(coefficients 
= coefficients, numClasses = numClasses, numFeatures = numFeatures) }) # Save fitted LinearSVCModel to the input path diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index d29af00affb98..ea45e394500e8 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -907,3 +907,19 @@ basenameSansExtFromUrl <- function(url) { isAtomicLengthOne <- function(x) { is.atomic(x) && length(x) == 1 } + +is_cran <- function() { + !identical(Sys.getenv("NOT_CRAN"), "true") +} + +is_windows <- function() { + .Platform$OS.type == "windows" +} + +hadoop_home_set <- function() { + !identical(Sys.getenv("HADOOP_HOME"), "") +} + +not_cran_or_windows_with_hadoop <- function() { + !is_cran() && (!is_windows() || hadoop_home_set()) +} diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index f3eaeb381afc4..c1c746828d24b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -38,9 +38,8 @@ test_that("spark.svmLinear", { expect_true(class(summary$coefficients[, 1]) == "numeric") coefs <- summary$coefficients[, "Estimate"] - expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expected_coefs <- c(-0.06004978, -0.1563083, -0.460648, 0.2276626, 1.055085) expect_true(all(abs(coefs - expected_coefs) < 0.1)) - expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) # Test prediction with string label prediction <- predict(model, training) @@ -50,15 +49,17 @@ test_that("spark.svmLinear", { expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) # Test model save and load - modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - 
unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # Test prediction with numeric label label <- c(0.0, 0.0, 0.0, 1.0, 1.0) @@ -128,15 +129,17 @@ test_that("spark.logit", { expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) # Test model save and load - modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - coefs <- summary(model)$coefficients - coefs2 <- summary(model2)$coefficients - expect_equal(coefs, coefs2) - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + } # R code to reproduce the result. 
# nolint start @@ -243,19 +246,21 @@ test_that("spark.mlp", { expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$numOfInputs, 4) - expect_equal(summary2$numOfOutputs, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + } # Test default parameter model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) @@ -354,16 +359,18 @@ test_that("spark.naiveBayes", { "Yes", "Yes", "No", "No")) # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + } # 
Test e1071::naiveBayes if (requireNamespace("e1071", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index df8e5968b27f4..8f71de1cbc7b5 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -53,18 +53,20 @@ test_that("spark.bisectingKmeans", { c(0, 1, 2, 3)) # Test model save/load - modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } }) test_that("spark.gaussianMixture", { @@ -125,18 +127,20 @@ test_that("spark.gaussianMixture", { expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - 
expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) + } }) test_that("spark.kmeans", { @@ -171,18 +175,20 @@ test_that("spark.kmeans", { expect_true(class(summary.model$coefficients[1, ]) == "numeric") # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + } # Test Kmeans on dataset that is sensitive to seed value col1 <- c(1, 2, 3, 4, 
0, 1, 2, 3, 4, 0) @@ -236,22 +242,24 @@ test_that("spark.lda with libsvm", { expect_true(logPrior <= 0 & !is.na(logPrior)) # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_true(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) - expect_equal(logPrior, stats2$logPrior) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) + } }) test_that("spark.lda with text input", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R index 1fa5375f9da31..4e10ca1e4f50b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_fpm.R +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -62,15 +62,17 @@ test_that("spark.fpGrowth", { expect_equivalent(expected_predictions, collect(predict(model, new_data))) - modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") - write.ml(model, modelPath, overwrite = TRUE) - loaded_model <- read.ml(modelPath) + if 
(not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) - expect_equivalent( - itemsets, - collect(spark.freqItemsets(loaded_model))) + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) - unlink(modelPath) + unlink(modelPath) + } model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) expect_equal( diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R index e3e2b15c71361..cc8064f88d27a 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -37,29 +37,31 @@ test_that("spark.als", { tolerance = 1e-4) # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - 
expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - unlink(modelPath) + unlink(modelPath) + } }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 44c98be906d81..b05fdd350ca28 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -401,14 +401,16 @@ test_that("spark.isoreg", { expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) # Test model save/load - modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) + } }) test_that("spark.survreg", { @@ -450,17 +452,19 @@ test_that("spark.survreg", { 2.390146, 2.891269, 2.891269), tolerance = 1e-4) 
# Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + } # Test survival::survreg if (requireNamespace("survival", quietly = TRUE)) { diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R index b283e734cec53..5fd6a38ecb4fa 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_tree.R +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -44,21 +44,23 @@ test_that("spark.gbt", { expect_equal(stats$numFeatures, 6) expect_equal(length(stats$treeWeights), 20) - modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - 
unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification # label must be binary - GBTClassifier currently only supports binary classification. @@ -76,17 +78,19 @@ test_that("spark.gbt", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) df <- suppressWarnings(createDataFrame(iris2)) @@ -136,21 +140,23 @@ 
test_that("spark.randomForest", { expect_equal(stats$numTrees, 20) expect_equal(stats$maxDepth, 5) - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$maxDepth, stats2$maxDepth) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + } # classification data <- suppressWarnings(createDataFrame(iris)) @@ -168,17 +174,19 @@ test_that("spark.randomForest", { expect_equal(length(grep("setosa", predictions)), 50) expect_equal(length(grep("versicolor", predictions)), 50) - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - 
expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) + if (not_cran_or_windows_with_hadoop()) { + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + } # Test numeric response variable labelToIndex <- function(species) { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index b633b78d5bb4d..9fc6e5dabecc3 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -61,7 +61,11 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR filesBefore <- list.files(path = sparkRDir, all.files = TRUE) -sparkSession <- sparkR.session(master = sparkRTestMaster) +sparkSession <- if (not_cran_or_windows_with_hadoop()) { + sparkR.session(master = sparkRTestMaster) + } else { + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + } sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -326,51 +330,53 @@ test_that("createDataFrame uses files for large objects", { }) test_that("read/write csv as DataFrame", { - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "NA,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - # default "header" is false, inferSchema to handle "year" as "int" - df <- read.df(csvPath, "csv", header = "true", 
inferSchema = "true") - expect_equal(count(df), 4) - expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) - expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), - sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) - - # since "year" is "int", let's skip the NA values - withoutna <- na.omit(df, how = "any", cols = "year") - expect_equal(count(withoutna), 3) - - unlink(csvPath) - csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") - mockLinesCsv <- c("year,make,model,comment,blank", - "\"2012\",\"Tesla\",\"S\",\"No comment\",", - "1997,Ford,E350,\"Go get one now they are going fast\",", - "2015,Chevy,Volt", - "Empty,Dummy,Placeholder") - writeLines(mockLinesCsv, csvPath) - - df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") - expect_equal(count(df2), 4) - withoutna2 <- na.omit(df2, how = "any", cols = "year") - expect_equal(count(withoutna2), 3) - expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) - - # writing csv file - csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") - write.df(df2, path = csvPath2, "csv", header = "true") - df3 <- read.df(csvPath2, "csv", header = "true") - expect_equal(nrow(df3), nrow(df2)) - expect_equal(colnames(df3), colnames(df2)) - csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) - expect_equal(colnames(df3), colnames(csv)) - - unlink(csvPath) - unlink(csvPath2) + if (not_cran_or_windows_with_hadoop()) { + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "NA,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + # default "header" is false, inferSchema to handle "year" as "int" + df <- read.df(csvPath, "csv", header = "true", inferSchema = "true") + expect_equal(count(df), 4) + 
expect_equal(columns(df), c("year", "make", "model", "comment", "blank")) + expect_equal(sort(unlist(collect(where(df, df$year == 2015)))), + sort(unlist(list(year = 2015, make = "Chevy", model = "Volt")))) + + # since "year" is "int", let's skip the NA values + withoutna <- na.omit(df, how = "any", cols = "year") + expect_equal(count(withoutna), 3) + + unlink(csvPath) + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") + mockLinesCsv <- c("year,make,model,comment,blank", + "\"2012\",\"Tesla\",\"S\",\"No comment\",", + "1997,Ford,E350,\"Go get one now they are going fast\",", + "2015,Chevy,Volt", + "Empty,Dummy,Placeholder") + writeLines(mockLinesCsv, csvPath) + + df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty") + expect_equal(count(df2), 4) + withoutna2 <- na.omit(df2, how = "any", cols = "year") + expect_equal(count(withoutna2), 3) + expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0) + + # writing csv file + csvPath2 <- tempfile(pattern = "csvtest2", fileext = ".csv") + write.df(df2, path = csvPath2, "csv", header = "true") + df3 <- read.df(csvPath2, "csv", header = "true") + expect_equal(nrow(df3), nrow(df2)) + expect_equal(colnames(df3), colnames(df2)) + csv <- read.csv(file = list.files(csvPath2, pattern = "^part", full.names = T)[[1]]) + expect_equal(colnames(df3), colnames(csv)) + + unlink(csvPath) + unlink(csvPath2) + } }) test_that("Support other types for options", { @@ -601,48 +607,50 @@ test_that("Collect DataFrame with complex types", { }) test_that("read/write json files", { - # Test read.df - df <- read.df(jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test read.df with a user defined schema - schema <- structType(structField("name", type = "string"), - structField("age", type = "double")) - - df1 <- read.df(jsonPath, "json", schema) - expect_is(df1, "SparkDataFrame") - expect_equal(dtypes(df1), list(c("name", "string"), c("age", 
"double"))) - - # Test loadDF - df2 <- loadDF(jsonPath, "json", schema) - expect_is(df2, "SparkDataFrame") - expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) - - # Test read.json - df <- read.json(jsonPath) - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - - # Test write.df - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") - write.df(df, jsonPath2, "json", mode = "overwrite") - - # Test write.json - jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") - write.json(df, jsonPath3) - - # Test read.json()/jsonFile() works with multiple input paths - jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) - expect_is(jsonDF1, "SparkDataFrame") - expect_equal(count(jsonDF1), 6) - # Suppress warnings because jsonFile is deprecated - jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) - expect_is(jsonDF2, "SparkDataFrame") - expect_equal(count(jsonDF2), 6) - - unlink(jsonPath2) - unlink(jsonPath3) + if (not_cran_or_windows_with_hadoop()) { + # Test read.df + df <- read.df(jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test read.df with a user defined schema + schema <- structType(structField("name", type = "string"), + structField("age", type = "double")) + + df1 <- read.df(jsonPath, "json", schema) + expect_is(df1, "SparkDataFrame") + expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) + + # Test loadDF + df2 <- loadDF(jsonPath, "json", schema) + expect_is(df2, "SparkDataFrame") + expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) + + # Test read.json + df <- read.json(jsonPath) + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + + # Test write.df + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".json") + write.df(df, jsonPath2, "json", mode = "overwrite") + + # Test write.json + jsonPath3 <- tempfile(pattern = "jsonPath3", fileext = ".json") + write.json(df, jsonPath3) + + # Test read.json()/jsonFile() 
works with multiple input paths + jsonDF1 <- read.json(c(jsonPath2, jsonPath3)) + expect_is(jsonDF1, "SparkDataFrame") + expect_equal(count(jsonDF1), 6) + # Suppress warnings because jsonFile is deprecated + jsonDF2 <- suppressWarnings(jsonFile(c(jsonPath2, jsonPath3))) + expect_is(jsonDF2, "SparkDataFrame") + expect_equal(count(jsonDF2), 6) + + unlink(jsonPath2) + unlink(jsonPath3) + } }) test_that("read/write json files - compression option", { @@ -736,33 +744,35 @@ test_that("test cache, uncache and clearCache", { }) test_that("insertInto() on a registered table", { - df <- read.df(jsonPath, "json") - write.df(df, parquetPath, "parquet", "overwrite") - dfParquet <- read.df(parquetPath, "parquet") - - lines <- c("{\"name\":\"Bob\", \"age\":24}", - "{\"name\":\"James\", \"age\":35}") - jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = ".tmp") - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - writeLines(lines, jsonPath2) - df2 <- read.df(jsonPath2, "json") - write.df(df2, parquetPath2, "parquet", "overwrite") - dfParquet2 <- read.df(parquetPath2, "parquet") - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1") - expect_equal(count(sql("select * from table1")), 5) - expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - expect_true(dropTempView("table1")) - - createOrReplaceTempView(dfParquet, "table1") - insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_equal(count(sql("select * from table1")), 2) - expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - expect_true(dropTempView("table1")) - - unlink(jsonPath2) - unlink(parquetPath2) + if (not_cran_or_windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + write.df(df, parquetPath, "parquet", "overwrite") + dfParquet <- read.df(parquetPath, "parquet") + + lines <- c("{\"name\":\"Bob\", \"age\":24}", + "{\"name\":\"James\", \"age\":35}") + jsonPath2 <- tempfile(pattern = "jsonPath2", fileext = 
".tmp") + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + writeLines(lines, jsonPath2) + df2 <- read.df(jsonPath2, "json") + write.df(df2, parquetPath2, "parquet", "overwrite") + dfParquet2 <- read.df(parquetPath2, "parquet") + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1") + expect_equal(count(sql("select * from table1")), 5) + expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") + expect_true(dropTempView("table1")) + + createOrReplaceTempView(dfParquet, "table1") + insertInto(dfParquet2, "table1", overwrite = TRUE) + expect_equal(count(sql("select * from table1")), 2) + expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") + expect_true(dropTempView("table1")) + + unlink(jsonPath2) + unlink(parquetPath2) + } }) test_that("tableToDF() returns a new DataFrame", { @@ -954,14 +964,16 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", }) test_that("setCheckpointDir(), checkpoint() on a DataFrame", { - checkpointDir <- file.path(tempdir(), "cproot") - expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) - - setCheckpointDir(checkpointDir) - df <- read.json(jsonPath) - df <- checkpoint(df) - expect_is(df, "SparkDataFrame") - expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + if (not_cran_or_windows_with_hadoop()) { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + } }) test_that("schema(), dtypes(), columns(), names() return the correct values/format", { @@ -1329,45 +1341,47 @@ test_that("column calculation", { }) test_that("test HiveContext", { - setHiveContext(sc) - - schema <- 
structType(structField("name", "string"), structField("age", "integer"), - structField("height", "float")) - createTable("people", source = "json", schema = schema) - df <- read.df(jsonPathNa, "json", schema) - insertInto(df, "people") - expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) - sql("DROP TABLE people") - - df <- createTable("json", jsonPath, "json") - expect_is(df, "SparkDataFrame") - expect_equal(count(df), 3) - df2 <- sql("select * from json") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "json2", "json", "append", path = jsonPath2) - df3 <- sql("select * from json2") - expect_is(df3, "SparkDataFrame") - expect_equal(count(df3), 3) - unlink(jsonPath2) - - hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "hivetestbl", path = hivetestDataPath) - df4 <- sql("select * from hivetestbl") - expect_is(df4, "SparkDataFrame") - expect_equal(count(df4), 3) - unlink(hivetestDataPath) - - parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) - df5 <- sql("select * from parquetest") - expect_is(df5, "SparkDataFrame") - expect_equal(count(df5), 3) - unlink(parquetDataPath) - - unsetHiveContext() + if (not_cran_or_windows_with_hadoop()) { + setHiveContext(sc) + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") + expect_is(df, "SparkDataFrame") + expect_equal(count(df), 3) + df2 <- sql("select * from json") + expect_is(df2, 
"SparkDataFrame") + expect_equal(count(df2), 3) + + jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "json2", "json", "append", path = jsonPath2) + df3 <- sql("select * from json2") + expect_is(df3, "SparkDataFrame") + expect_equal(count(df3), 3) + unlink(jsonPath2) + + hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "hivetestbl", path = hivetestDataPath) + df4 <- sql("select * from hivetestbl") + expect_is(df4, "SparkDataFrame") + expect_equal(count(df4), 3) + unlink(hivetestDataPath) + + parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) + df5 <- sql("select * from parquetest") + expect_is(df5, "SparkDataFrame") + expect_equal(count(df5), 3) + unlink(parquetDataPath) + + unsetHiveContext() + } }) test_that("column operators", { @@ -2420,34 +2434,36 @@ test_that("read/write ORC files - compression option", { }) test_that("read/write Parquet files", { - df <- read.df(jsonPath, "json") - # Test write.df and read.df - write.df(df, parquetPath, "parquet", mode = "overwrite") - df2 <- read.df(parquetPath, "parquet") - expect_is(df2, "SparkDataFrame") - expect_equal(count(df2), 3) - - # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile - parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") - write.parquet(df, parquetPath2) - parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - suppressWarnings(saveAsParquetFile(df, parquetPath3)) - parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) - expect_is(parquetDF, "SparkDataFrame") - expect_equal(count(parquetDF), count(df) * 2) - parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) - expect_is(parquetDF2, "SparkDataFrame") - expect_equal(count(parquetDF2), count(df) * 2) - - # Test if varargs works with variables - saveMode <- "overwrite" - mergeSchema <- "true" - 
parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") - write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) - - unlink(parquetPath2) - unlink(parquetPath3) - unlink(parquetPath4) + if (not_cran_or_windows_with_hadoop()) { + df <- read.df(jsonPath, "json") + # Test write.df and read.df + write.df(df, parquetPath, "parquet", mode = "overwrite") + df2 <- read.df(parquetPath, "parquet") + expect_is(df2, "SparkDataFrame") + expect_equal(count(df2), 3) + + # Test write.parquet/saveAsParquetFile and read.parquet/parquetFile + parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") + write.parquet(df, parquetPath2) + parquetPath3 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + suppressWarnings(saveAsParquetFile(df, parquetPath3)) + parquetDF <- read.parquet(c(parquetPath2, parquetPath3)) + expect_is(parquetDF, "SparkDataFrame") + expect_equal(count(parquetDF), count(df) * 2) + parquetDF2 <- suppressWarnings(parquetFile(parquetPath2, parquetPath3)) + expect_is(parquetDF2, "SparkDataFrame") + expect_equal(count(parquetDF2), count(df) * 2) + + # Test if varargs works with variables + saveMode <- "overwrite" + mergeSchema <- "true" + parquetPath4 <- tempfile(pattern = "parquetPath3", fileext = ".parquet") + write.df(df, parquetPath3, "parquet", mode = saveMode, mergeSchema = mergeSchema) + + unlink(parquetPath2) + unlink(parquetPath3) + unlink(parquetPath4) + } }) test_that("read/write Parquet files - compression option/mode", { diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 050778a895c0f..7d356e8fc1c00 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -92,6 +92,9 @@ private[deploy] object RPackageUtils extends Logging { * Exposed for testing. 
*/ private[deploy] def checkManifestForR(jar: JarFile): Boolean = { + if (jar.getManifest == null) { + return false + } val manifest = jar.getManifest.getMainAttributes manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true" } diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index f50cb38311db2..42b8cde650390 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -243,16 +243,22 @@ private[deploy] object IvyTestUtils { withManifest: Option[Manifest] = None): File = { val jarFile = new File(dir, artifactName(artifact, useIvyLayout)) val jarFileStream = new FileOutputStream(jarFile) - val manifest = withManifest.getOrElse { - val mani = new Manifest() + val manifest: Manifest = withManifest.getOrElse { if (withR) { + val mani = new Manifest() val attr = mani.getMainAttributes attr.put(Name.MANIFEST_VERSION, "1.0") attr.put(new Name("Spark-HasRPackage"), "true") + mani + } else { + null } - mani } - val jarStream = new JarOutputStream(jarFileStream, manifest) + val jarStream = if (manifest != null) { + new JarOutputStream(jarFileStream, manifest) + } else { + new JarOutputStream(jarFileStream) + } for (file <- files) { val jarEntry = new JarEntry(file._1) diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 005587051b6ad..5e0bf6d438dc8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -133,6 +133,16 @@ class RPackageUtilsSuite } } + test("jars without manifest return false") { + IvyTestUtils.withRepository(main, None, None) { repo => + val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil, + useIvyLayout = false, withR = false, 
None) + val jarFile = new JarFile(jar) + assert(jarFile.getManifest == null, "jar file should have null manifest") + assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false") + } + } + test("SparkR zipping works properly") { val tempDir = Files.createTempDir() Utils.tryWithSafeFinally { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala index cfd043b66ed94..0dd1f1146fbf8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LinearSVCWrapper.scala @@ -38,9 +38,17 @@ private[r] class LinearSVCWrapper private ( private val svcModel: LinearSVCModel = pipeline.stages(1).asInstanceOf[LinearSVCModel] - lazy val coefficients: Array[Double] = svcModel.coefficients.toArray + lazy val rFeatures: Array[String] = if (svcModel.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } - lazy val intercept: Double = svcModel.intercept + lazy val rCoefficients: Array[Double] = if (svcModel.getFitIntercept) { + Array(svcModel.intercept) ++ svcModel.coefficients.toArray + } else { + svcModel.coefficients.toArray + } lazy val numClasses: Int = svcModel.numClasses diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d58b8acefdade..d130962c63918 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1336,7 +1336,7 @@ class Analyzer( // Category 1: // BroadcastHint, Distinct, LeafNode, Repartition, and SubqueryAlias - case _: BroadcastHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => + case _: ResolvedHint | _: Distinct | _: LeafNode | _: Repartition | _: SubqueryAlias => // Category 2: // 
These operators can be anywhere in a correlated subquery. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index ea4560aac7259..2e3ac3e474866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -399,7 +399,7 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) - case _: Hint => + case _: UnresolvedHint => throw new IllegalStateException( "Internal error: logical hint operator should have been removed during analysis") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index df688fa0e58ae..9dfd84cbc9941 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -57,11 +57,11 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { case u: UnresolvedRelation if toBroadcast.exists(resolver(_, u.tableIdentifier.table)) => - BroadcastHint(plan) + ResolvedHint(plan, isBroadcastable = Option(true)) case r: SubqueryAlias if toBroadcast.exists(resolver(_, r.alias)) => - BroadcastHint(plan) + ResolvedHint(plan, isBroadcastable = Option(true)) - case _: BroadcastHint | _: View | _: With | _: SubqueryAlias => + case _: ResolvedHint | _: View | _: With | _: SubqueryAlias => // Don't traverse down these nodes. // For an existing broadcast hint, there is no point going down (if we do, we either // won't change the structure, or will introduce another broadcast hint that is useless. 
@@ -85,10 +85,10 @@ object ResolveHints { } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => + case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => if (h.parameters.isEmpty) { // If there is no table alias specified, turn the entire subtree into a BroadcastHint. - BroadcastHint(h.child) + ResolvedHint(h.child, isBroadcastable = Option(true)) } else { // Otherwise, find within the subtree query plans that should be broadcasted. applyBroadcastHint(h.child, h.parameters.toSet) @@ -102,7 +102,7 @@ object ResolveHints { */ object RemoveAllHints extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case h: Hint => h.child + case h: UnresolvedHint => h.child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cc0cbba275b81..2f328ccc49451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -203,6 +203,8 @@ case class BucketSpec( * sensitive schema was unable to be read from the table properties. * Used to trigger case-sensitive schema inference at query time, when * configured. + * @param ignoredProperties is a list of table properties that are used by the underlying table + * but currently ignored by Spark SQL.
*/ case class CatalogTable( identifier: TableIdentifier, @@ -221,7 +223,8 @@ case class CatalogTable( comment: Option[String] = None, unsupportedFeatures: Seq[String] = Seq.empty, tracksPartitionsInCatalog: Boolean = false, - schemaPreservesCase: Boolean = true) { + schemaPreservesCase: Boolean = true, + ignoredProperties: Map[String, String] = Map.empty) { import CatalogTable._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1802cd4bb131b..ae2f6bfa94ae7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -862,7 +862,7 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { // Note that some operators (e.g. project, aggregate, union) are being handled separately // (earlier in this rule). case _: AppendColumns => true - case _: BroadcastHint => true + case _: ResolvedHint => true case _: Distinct => true case _: Generate => true case _: Pivot => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index d3ef5ea840919..8931eb2c8f3b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -478,7 +478,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { case _: Distinct => true case _: AppendColumns => true case _: AppendColumnsWithObject => true - case _: BroadcastHint => true + case _: ResolvedHint => true case _: RepartitionByExpression => true case _: Repartition => true case _: Sort => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f033fd4834c96..7d2e3a6fe7580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -533,13 +533,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } /** - * Add a [[Hint]] to a logical plan. + * Add a [[UnresolvedHint]] to a logical plan. */ private def withHints( ctx: HintContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { val stmt = ctx.hintStatement - Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index d39b0ef7e1d8a..ef925f92ecc7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -65,8 +65,8 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) - case BroadcastHint(child) => - collectProjectsAndFilters(child) + case h: ResolvedHint => + collectProjectsAndFilters(h.child) case other => (None, Nil, other, Map.empty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index 3d4efef953a64..81bb374cb0500 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala 
@@ -68,6 +68,11 @@ case class Statistics( s"isBroadcastable=$isBroadcastable" ).filter(_.nonEmpty).mkString(", ") } + + /** Must be called when computing stats for a join operator to reset hints. */ + def resetHintsForJoin(): Statistics = copy( + isBroadcastable = false + ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index d291ca0020838..9f34b371740bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -364,7 +364,7 @@ case class Join( case _ => // Make sure we don't propagate isBroadcastable in other joins, because // they could explode the size. - super.computeStats(conf).copy(isBroadcastable = false) + super.computeStats(conf).resetHintsForJoin() } if (conf.cboEnabled) { @@ -375,26 +375,6 @@ case class Join( } } -/** - * A hint for the optimizer that we should broadcast the `child` if used in a join operator. - */ -case class BroadcastHint(child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - - // set isBroadcastable to true so the child will be broadcasted - override def computeStats(conf: SQLConf): Statistics = - child.stats(conf).copy(isBroadcastable = true) -} - -/** - * A general hint for the child. This node will be eliminated post analysis. - * A pair of (name, parameters). - */ -case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode { - override lazy val resolved: Boolean = false - override def output: Seq[Attribute] = child.output -} - /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala new file mode 100644 index 0000000000000..9bcbfbb4d1397 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.internal.SQLConf + +/** + * A general hint for the child that is not yet resolved. This node is generated by the parser and + * will be eliminated post analysis. + * A pair of (name, parameters). + */ +case class UnresolvedHint(name: String, parameters: Seq[String], child: LogicalPlan) + extends UnaryNode { + + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} + +/** + * A resolved hint node. The analyzer should convert all [[UnresolvedHint]] into [[ResolvedHint]].
+ */ +case class ResolvedHint( + child: LogicalPlan, + isBroadcastable: Option[Boolean] = None) + extends UnaryNode { + + override def output: Seq[Attribute] = child.output + + override def computeStats(conf: SQLConf): Statistics = { + val stats = child.stats(conf) + isBroadcastable.map(x => stats.copy(isBroadcastable = x)).getOrElse(stats) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index d101e2227462d..bb914e11a139a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -28,68 +28,70 @@ class ResolveHintsSuite extends AnalysisTest { test("invalid hints should be ignored") { checkAnalysis( - Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), + UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), testRelation, caseSensitive = false) } test("case-sensitive or insensitive parameters") { checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + ResolvedHint(testRelation, isBroadcastable = Option(true)), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("table"), table("TaBlE")), + ResolvedHint(testRelation, isBroadcastable = Option(true)), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + ResolvedHint(testRelation, isBroadcastable = Option(true)), caseSensitive = true) checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("TaBlE")), + UnresolvedHint("MAPJOIN", Seq("table"), 
table("TaBlE")), testRelation, caseSensitive = true) } test("multiple broadcast hint aliases") { checkAnalysis( - Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), - Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None), + UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), + Join(ResolvedHint(testRelation, isBroadcastable = Option(true)), + ResolvedHint(testRelation2, isBroadcastable = Option(true)), Inner, None), caseSensitive = false) } test("do not traverse past existing broadcast hints") { checkAnalysis( - Hint("MAPJOIN", Seq("table"), BroadcastHint(table("table").where('a > 1))), - BroadcastHint(testRelation.where('a > 1)).analyze, + UnresolvedHint("MAPJOIN", Seq("table"), + ResolvedHint(table("table").where('a > 1), isBroadcastable = Option(true))), + ResolvedHint(testRelation.where('a > 1), isBroadcastable = Option(true)).analyze, caseSensitive = false) } test("should work for subqueries") { checkAnalysis( - Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), + ResolvedHint(testRelation, isBroadcastable = Option(true)), caseSensitive = false) checkAnalysis( - Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), - BroadcastHint(testRelation), + UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), + ResolvedHint(testRelation, isBroadcastable = Option(true)), caseSensitive = false) // Negative case: if the alias doesn't match, don't match the original table name. 
checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), + UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), testRelation, caseSensitive = false) } test("do not traverse past subquery alias") { checkAnalysis( - Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), + UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), testRelation.where('a > 1).analyze, caseSensitive = false) } @@ -102,7 +104,8 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze, + ResolvedHint(testRelation.where('a > 1).select('a), isBroadcastable = Option(true)) + .select('a).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 589607e3ad5cb..a0a0daea7d075 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -321,15 +321,14 @@ class ColumnPruningSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze + ResolvedHint(SubqueryAlias("y", input)), Inner, None)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input))), + ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 950aa2379517e..d4d281e7e05db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -798,12 +798,12 @@ class FilterPushdownSuite extends PlanTest { } test("broadcast hint") { - val originalQuery = BroadcastHint(testRelation) + val originalQuery = ResolvedHint(testRelation) .where('a === 2L && 'b + Rand(10).as("rnd") === 3) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + val correctAnswer = ResolvedHint(testRelation.where('a === 2L)) .where('b + Rand(10).as("rnd") === 3) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index a43d78c7bd447..105407d43bf39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -129,14 +129,14 @@ class JoinOptimizationSuite extends PlanTest { Project(Seq($"x.key", $"y.key"), Join( SubqueryAlias("x", input), - BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze + ResolvedHint(SubqueryAlias("y", input)), Cross, None)).analyze val optimized = Optimize.execute(query) val expected = Join( Project(Seq($"x.key"), SubqueryAlias("x", input)), - BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), + ResolvedHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 8bc2010cabece..4d08f016a4a16 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -463,22 +463,30 @@ class ExpressionParserSuite extends PlanTest { assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser) // Escaped characters. - assertEqual("'\0'", "\u0000", parser) // ASCII NUL (X'00') + // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work + // when ESCAPED_STRING_LITERALS is enabled. + // It is parsed literally. + assertEqual("'\\0'", "\\0", parser) // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled. val e = intercept[ParseException](parser.parseExpression("'\''")) assert(e.message.contains("extraneous input '''")) - assertEqual("'\"'", "\"", parser) // Double quote - assertEqual("'\b'", "\b", parser) // Backspace - assertEqual("'\n'", "\n", parser) // Newline - assertEqual("'\r'", "\r", parser) // Carriage return - assertEqual("'\t'", "\t", parser) // Tab character - - // Octals - assertEqual("'\110\145\154\154\157\041'", "Hello!", parser) - // Unicode - assertEqual("'\u0057\u006F\u0072\u006C\u0064\u0020\u003A\u0029'", "World :)", parser) + // The unescape special characters (e.g., "\\t") for 2.0+ don't work + // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally. + assertEqual("'\\\"'", "\\\"", parser) // Double quote + assertEqual("'\\b'", "\\b", parser) // Backspace + assertEqual("'\\n'", "\\n", parser) // Newline + assertEqual("'\\r'", "\\r", parser) // Carriage return + assertEqual("'\\t'", "\\t", parser) // Tab character + + // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. 
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser) + // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled. + // They are parsed literally. + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", + "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser) } else { // Default behavior diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index d78741d032f38..134e761460881 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -534,30 +534,31 @@ class PlanParserSuite extends PlanTest { comparePlans( parsePlan("SELECT /*+ HINT */ * FROM t"), - Hint("HINT", Seq.empty, table("t").select(star()))) + UnresolvedHint("HINT", Seq.empty, table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), - Hint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), - Hint("MAPJOIN", Seq("u"), table("t").select(star()))) + UnresolvedHint("MAPJOIN", Seq("u"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), - Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + UnresolvedHint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ INDEX(t, emp_job_ix) */ * FROM t"), - Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + UnresolvedHint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), - Hint("MAPJOIN", Seq("default.t"), 
table("default.t").select(star()))) + UnresolvedHint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) comparePlans( parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), - Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + UnresolvedHint("MAPJOIN", Seq("t"), + table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index b06871f96f0d8..81b91e63b8f67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -45,7 +45,7 @@ class BasicStatsEstimationSuite extends StatsEstimationTestBase { expectedStatsCboOn = filterStatsCboOn, expectedStatsCboOff = filterStatsCboOff) - val broadcastHint = BroadcastHint(filter) + val broadcastHint = ResolvedHint(filter, isBroadcastable = Option(true)) checkStats( broadcastHint, expectedStatsCboOn = filterStatsCboOn.copy(isBroadcastable = true), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 37e3dfabd0b21..712841835acd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -492,7 +492,8 @@ class TreeNodeSuite extends SparkFunSuite { "tracksPartitionsInCatalog" -> false, "properties" -> JNull, "unsupportedFeatures" -> List.empty[String], - "schemaPreservesCase" -> JBool(true))) + "schemaPreservesCase" -> JBool(true), + "ignoredProperties" -> JNull)) // For unknown case 
class, returns JNull. val bigValue = new Array[Int](10000) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 53773f18ce553..cbab029b87b2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1174,7 +1174,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { - Hint(name, parameters, logicalPlan) + UnresolvedHint(name, parameters, logicalPlan) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 73541c22c6308..5981b49da277e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -433,7 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil - case BroadcastHint(child) => planLater(child) :: Nil + case h: ResolvedHint => planLater(h.child) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5edf03666ac22..563eae0b6483f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import 
org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.catalyst.plans.logical.ResolvedHint import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.SQLConf @@ -1019,7 +1019,8 @@ object functions { * @since 1.5.0 */ def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) + Dataset[T](df.sparkSession, + ResolvedHint(df.logicalPlan, isBroadcastable = Option(true)))(df.exprEnc) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index e328b86437d62..a86a86d408906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -174,6 +174,7 @@ object JdbcDialects { registerDialect(MsSqlServerDialect) registerDialect(DerbyDialect) registerDialect(OracleDialect) + registerDialect(TeradataDialect) /** * Fetch the JdbcDialect class corresponding to a given database url. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala new file mode 100644 index 0000000000000..5749b791fca25 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object TeradataDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = { url.startsWith("jdbc:teradata") } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 26c45e092dc65..afb8ced53e25c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -157,7 +157,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } test("broadcast hint in SQL") { - import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + import org.apache.spark.sql.catalyst.plans.logical.{ResolvedHint, Join} spark.range(10).createOrReplaceTempView("t") spark.range(10).createOrReplaceTempView("u") @@ -170,12 +170,12 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution .optimizedPlan - assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - 
assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) - assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(plan1.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(plan2.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[ResolvedHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[ResolvedHint]) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d9f3689411ab7..70bee929b31da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -922,6 +922,18 @@ class JDBCSuite extends SparkFunSuite assert(e2.contains("User specified schema not supported with `jdbc`")) } + test("SPARK-15648: teradataDialect StringType data mapping") { + val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + assert(teradataDialect.getJDBCType(StringType). + map(_.databaseTypeDefinition).get == "VARCHAR(255)") + } + + test("SPARK-15648: teradataDialect BooleanType data mapping") { + val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db") + assert(teradataDialect.getJDBCType(BooleanType). 
+ map(_.databaseTypeDefinition).get == "CHAR(1)") + } + test("Checking metrics correctness with JDBC") { val foobarCnt = spark.table("foobar").count() val res = InputOutputMetricsHelper.run(sql("SELECT * FROM foobar").toDF()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 4f090d545cd18..9c60d22d35ce1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -119,20 +119,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { case relation: CatalogRelation if DDLUtils.isHiveTable(relation.tableMeta) && relation.tableMeta.stats.isEmpty => val table = relation.tableMeta - // TODO: check if this estimate is valid for tables after partition pruning. - // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be - // relatively cheap if parameters for the table are populated into the metastore. - // Besides `totalSize`, there are also `numFiles`, `numRows`, `rawDataSize` keys - // (see StatsSetupConst in Hive) that we can look at in the future. - // When table is external,`totalSize` is always zero, which will influence join strategy - // so when `totalSize` is zero, use `rawDataSize` instead. 
- val totalSize = table.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) - val rawDataSize = table.properties.get(StatsSetupConst.RAW_DATA_SIZE).map(_.toLong) - val sizeInBytes = if (totalSize.isDefined && totalSize.get > 0) { - totalSize.get - } else if (rawDataSize.isDefined && rawDataSize.get > 0) { - rawDataSize.get - } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { + val sizeInBytes = if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { try { val hadoopConf = session.sessionState.newHadoopConf() val tablePath = new Path(table.location) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 04f2751e79a51..b970be740ab51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Order} @@ -414,6 +415,47 @@ private[hive] class HiveClientImpl( val properties = Option(h.getParameters).map(_.asScala.toMap).orNull + // Hive-generated Statistics are also recorded in ignoredProperties + val ignoredProperties = scala.collection.mutable.Map.empty[String, String] + for (key <- HiveStatisticsProperties; value <- properties.get(key)) { + ignoredProperties += key -> value + } + + val excludedTableProperties = HiveStatisticsProperties ++ Set( + // The property value of "comment" is moved to the dedicated field "comment" + "comment", + // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". 
This is added + // in the function toHiveTable. + "EXTERNAL" + ) + + val filteredProperties = properties.filterNot { + case (key, _) => excludedTableProperties.contains(key) + } + val comment = properties.get("comment") + + val totalSize = properties.get(StatsSetupConst.TOTAL_SIZE).map(BigInt(_)) + val rawDataSize = properties.get(StatsSetupConst.RAW_DATA_SIZE).map(BigInt(_)) + val rowCount = properties.get(StatsSetupConst.ROW_COUNT).map(BigInt(_)).filter(_ >= 0) + // TODO: check if this estimate is valid for tables after partition pruning. + // NOTE: getting `totalSize` directly from params is kind of hacky, but this should be + // relatively cheap if parameters for the table are populated into the metastore. + // Currently, only totalSize, rawDataSize, and rowCount are used to build the field `stats` + // TODO: stats should include all the other two fields (`numFiles` and `numPartitions`). + // (see StatsSetupConst in Hive) + val stats = + // When table is external, `totalSize` is always zero, which will influence join strategy + // so when `totalSize` is zero, use `rawDataSize` instead. When `rawDataSize` is also zero, + // return None. Later, we will use the other ways to estimate the statistics. + if (totalSize.isDefined && totalSize.get > 0L) { + Some(CatalogStatistics(sizeInBytes = totalSize.get, rowCount = rowCount)) + } else if (rawDataSize.isDefined && rawDataSize.get > 0) { + Some(CatalogStatistics(sizeInBytes = rawDataSize.get, rowCount = rowCount)) + } else { + // TODO: still fill the rowCount even if sizeInBytes is empty. Might break anything? + None + } + CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -451,13 +493,15 @@ private[hive] class HiveClientImpl( ), // For EXTERNAL_TABLE, the table properties has a particular field "EXTERNAL". This is added // in the function toHiveTable. 
- properties = properties.filter(kv => kv._1 != "comment" && kv._1 != "EXTERNAL"), - comment = properties.get("comment"), + properties = filteredProperties, + stats = stats, + comment = comment, // In older versions of Spark(before 2.2.0), we expand the view original text and store // that into `viewExpandedText`, and that should be used in view resolution. So we get // `viewExpandedText` instead of `viewOriginalText` for viewText here. viewText = Option(h.getViewExpandedText), - unsupportedFeatures = unsupportedFeatures) + unsupportedFeatures = unsupportedFeatures, + ignoredProperties = ignoredProperties.toMap) } } @@ -474,7 +518,12 @@ private[hive] class HiveClientImpl( } override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState { - val hiveTable = toHiveTable(table, Some(userName)) + // getTableOption removes all the Hive-specific properties. Here, we fill them back to ensure + // these properties are still available to the others that share the same Hive metastore. + // If users explicitly alter these Hive-specific properties through ALTER TABLE DDL, we respect + // these user-specified values. 
+ val hiveTable = toHiveTable( + table.copy(properties = table.ignoredProperties ++ table.properties), Some(userName)) // Do not use `table.qualifiedName` here because this may be a rename val qualifiedTableName = s"${table.database}.$tableName" shim.alterTable(client, qualifiedTableName, hiveTable) @@ -956,4 +1005,14 @@ private[hive] object HiveClientImpl { parameters = if (hp.getParameters() != null) hp.getParameters().asScala.toMap else Map.empty) } + + // Below is the key of table properties for storing Hive-generated statistics + private val HiveStatisticsProperties = Set( + StatsSetupConst.COLUMN_STATS_ACCURATE, + StatsSetupConst.NUM_FILES, + StatsSetupConst.NUM_PARTITIONS, + StatsSetupConst.ROW_COUNT, + StatsSetupConst.RAW_DATA_SIZE, + StatsSetupConst.TOTAL_SIZE + ) } diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 9bf84ab1fb7a2..df7988f542b71 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -19,13 +19,17 @@ package org.apache.spark.sql.hive.test import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.SparkSession import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.client.HiveClient trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive + protected val hiveClient: HiveClient = + spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client protected override def afterAll(): Unit = { try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 7584f1741c62b..d97b11e447fe2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -646,7 +646,6 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle } test("SPARK-15887: hive-site.xml should be loaded") { - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client assert(hiveClient.getConf("hive.in.test", "") == "true") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala index 705d43f1f3aba..3bd3d0d6db355 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala @@ -35,10 +35,6 @@ import org.apache.spark.util.Utils class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach { - // To test `HiveExternalCatalog`, we need to read/write the raw table meta from/to hive client. 
- val hiveClient: HiveClient = - spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - val tempDir = Utils.createTempDir().getCanonicalFile val tempDirUri = tempDir.toURI val tempDirStr = tempDir.getAbsolutePath diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index b554694815571..c785aca985820 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -52,11 +52,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile } - // To test `HiveExternalCatalog`, we need to read the raw table metadata(schema, partition - // columns and bucket specification are still in table properties) from hive client. - private def hiveClient: HiveClient = - sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client - test("persistent JSON table") { withTable("jsonTable") { sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala index 081153df8e732..fad81c7e9474e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ShowCreateTableSuite.scala @@ -325,26 +325,20 @@ class ShowCreateTableSuite extends QueryTest with SQLTestUtils with TestHiveSing "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", // The following are hive specific schema parameters which we do not need to match exactly. 
- "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", - "minFileSize", - // EXTERNAL is not non-deterministic, but it is filtered out for external tables. - "EXTERNAL" + "minFileSize" ) table.copy( createTime = 0L, lastAccessTime = 0L, - properties = table.properties.filterKeys(!nondeterministicProps.contains(_)) + properties = table.properties.filterKeys(!nondeterministicProps.contains(_)), + stats = None, + ignoredProperties = Map.empty ) } - assert(normalize(actual) == normalize(expected)) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 3191b9975fbf9..5d52f8baa3b94 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,11 +19,14 @@ package org.apache.spark.sql.hive import java.io.{File, PrintWriter} +import org.apache.hadoop.hive.common.StatsSetupConst import scala.reflect.ClassTag +import scala.util.matching.Regex import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.joins._ @@ -61,7 +64,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto val relation = spark.table("csv_table").queryExecution.analyzed.children.head .asInstanceOf[CatalogRelation] - val properties = relation.tableMeta.properties + val properties = relation.tableMeta.ignoredProperties assert(properties("totalSize").toLong <= 0, "external table totalSize must be <= 0") assert(properties("rawDataSize").toLong <= 0, "external table rawDataSize must be <= 0") @@ -175,7 +178,7 @@ 
class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"INSERT INTO TABLE $textTable SELECT * FROM src") checkTableStats( textTable, - hasSizeInBytes = false, + hasSizeInBytes = true, expectedRowCounts = None) // noscan won't count the number of rows @@ -215,6 +218,210 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + private def createNonPartitionedTable( + tabName: String, + analyzedBySpark: Boolean = true, + analyzedByHive: Boolean = true): Unit = { + sql( + s""" + |CREATE TABLE $tabName (key STRING, value STRING) + |STORED AS TEXTFILE + |TBLPROPERTIES ('prop1' = 'val1', 'prop2' = 'val2') + """.stripMargin) + sql(s"INSERT INTO TABLE $tabName SELECT * FROM src") + if (analyzedBySpark) sql(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") + // This is to mimic the scenario in which Hive generates statistics before we read it + if (analyzedByHive) hiveClient.runSqlHive(s"ANALYZE TABLE $tabName COMPUTE STATISTICS") + val describeResult1 = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val tableMetadata = + spark.sessionState.catalog.getTableMetadata(TableIdentifier(tabName)).properties + // statistics info is not contained in the metadata of the original table + assert(Seq(StatsSetupConst.COLUMN_STATS_ACCURATE, + StatsSetupConst.NUM_FILES, + StatsSetupConst.NUM_PARTITIONS, + StatsSetupConst.ROW_COUNT, + StatsSetupConst.RAW_DATA_SIZE, + StatsSetupConst.TOTAL_SIZE).forall(!tableMetadata.contains(_))) + + if (analyzedByHive) { + assert(StringUtils.filterPattern(describeResult1, "*numRows\\s+500*").nonEmpty) + } else { + assert(StringUtils.filterPattern(describeResult1, "*numRows\\s+500*").isEmpty) + } + } + + private def extractStatsPropValues( + descOutput: Seq[String], + propKey: String): Option[BigInt] = { + val str = descOutput + .filterNot(_.contains(HiveExternalCatalog.STATISTICS_PREFIX)) + .filter(_.contains(propKey)) + if (str.isEmpty) { + None + } else { + assert(str.length == 1,
"found more than one matches") + val pattern = new Regex(s"""$propKey\\s+(-?\\d+)""") + val pattern(value) = str.head.trim + Option(BigInt(value)) + } + } + + test("get statistics when not analyzed in both Hive and Spark") { + val tabName = "tab1" + withTable(tabName) { + createNonPartitionedTable(tabName, analyzedByHive = false, analyzedBySpark = false) + checkTableStats( + tabName, hasSizeInBytes = true, expectedRowCounts = None) + + // ALTER TABLE SET TBLPROPERTIES invalidates some contents of Hive specific statistics + // This is triggered by the Hive alterTable API + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + val numRows = extractStatsPropValues(describeResult, "numRows") + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(rawDataSize.isEmpty, "rawDataSize should not be shown without table analysis") + assert(numRows.isEmpty, "numRows should not be shown without table analysis") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + } + } + + test("alter table rename after analyze table") { + Seq(true, false).foreach { analyzedBySpark => + val oldName = "tab1" + val newName = "tab2" + withTable(oldName, newName) { + createNonPartitionedTable(oldName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) + val fetchedStats1 = checkTableStats( + oldName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + sql(s"ALTER TABLE $oldName RENAME TO $newName") + val fetchedStats2 = checkTableStats( + newName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + assert(fetchedStats1 == fetchedStats2) + + // ALTER TABLE RENAME does not affect the contents of Hive specific statistics + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $newName") + + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + val numRows = extractStatsPropValues(describeResult, "numRows") + val 
totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(rawDataSize.isDefined && rawDataSize.get > 0, "rawDataSize is lost") + assert(numRows.isDefined && numRows.get == 500, "numRows is lost") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + } + } + } + + test("alter table SET TBLPROPERTIES after analyze table") { + Seq(true, false).foreach { analyzedBySpark => + val tabName = "tab1" + withTable(tabName) { + createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) + val fetchedStats1 = checkTableStats( + tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + sql(s"ALTER TABLE $tabName SET TBLPROPERTIES ('foo' = 'a')") + val fetchedStats2 = checkTableStats( + tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + assert(fetchedStats1 == fetchedStats2) + + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + + // ALTER TABLE SET TBLPROPERTIES invalidates some Hive specific statistics + // This is triggered by the Hive alterTable API + val numRows = extractStatsPropValues(describeResult, "numRows") + assert(numRows.isDefined && numRows.get == -1, "numRows is lost") + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + } + } + } + + test("alter table UNSET TBLPROPERTIES after analyze table") { + Seq(true, false).foreach { analyzedBySpark => + val tabName = "tab1" + withTable(tabName) { + createNonPartitionedTable(tabName, analyzedByHive = true, analyzedBySpark = analyzedBySpark) + val fetchedStats1 = checkTableStats( + tabName, hasSizeInBytes = true, expectedRowCounts = Some(500)) + sql(s"ALTER TABLE $tabName UNSET TBLPROPERTIES ('prop1')") + val fetchedStats2 = checkTableStats( + tabName, hasSizeInBytes = 
true, expectedRowCounts = Some(500)) + assert(fetchedStats1 == fetchedStats2) + + val describeResult = hiveClient.runSqlHive(s"DESCRIBE FORMATTED $tabName") + + val totalSize = extractStatsPropValues(describeResult, "totalSize") + assert(totalSize.isDefined && totalSize.get > 0, "totalSize is lost") + + // ALTER TABLE UNSET TBLPROPERTIES invalidates some Hive specific statistics + // This is triggered by the Hive alterTable API + val numRows = extractStatsPropValues(describeResult, "numRows") + assert(numRows.isDefined && numRows.get == -1, "numRows is lost") + val rawDataSize = extractStatsPropValues(describeResult, "rawDataSize") + assert(rawDataSize.isDefined && rawDataSize.get == -1, "rawDataSize is lost") + } + } + } + + test("add/drop partitions - managed table") { + val catalog = spark.sessionState.catalog + val managedTable = "partitionedTable" + withTable(managedTable) { + sql( + s""" + |CREATE TABLE $managedTable (key INT, value STRING) + |PARTITIONED BY (ds STRING, hr STRING) + """.stripMargin) + + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $managedTable + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + + checkTableStats( + managedTable, hasSizeInBytes = false, expectedRowCounts = None) + + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + + val stats1 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) + + sql( + s""" + |ALTER TABLE $managedTable DROP PARTITION (ds='2008-04-08'), + |PARTITION (hr='12') + """.stripMargin) + assert(catalog.listPartitions(TableIdentifier(managedTable)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + + val stats2 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(4)) + assert(stats1 == stats2) + + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + + val stats3 = checkTableStats( + managedTable, hasSizeInBytes = true, 
expectedRowCounts = Some(1)) + assert(stats2.get.sizeInBytes > stats3.get.sizeInBytes) + + sql(s"ALTER TABLE $managedTable ADD PARTITION (ds='2008-04-08', hr='12')") + sql(s"ANALYZE TABLE $managedTable COMPUTE STATISTICS") + val stats4 = checkTableStats( + managedTable, hasSizeInBytes = true, expectedRowCounts = Some(1)) + + assert(stats2.get.sizeInBytes > stats4.get.sizeInBytes) + assert(stats4.get.sizeInBytes == stats3.get.sizeInBytes) + } + } + test("test statistics of LogicalRelation converted from Hive serde tables") { val parquetTable = "parquetTable" val orcTable = "orcTable" @@ -232,7 +439,8 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } withSQLConf("spark.sql.hive.convertMetastoreOrc" -> "true") { - checkTableStats(orcTable, hasSizeInBytes = false, expectedRowCounts = None) + // We still can get tableSize from Hive before Analyze + checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) } @@ -254,7 +462,6 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto sql(s"analyze table $tableName compute STATISTICS FOR COLUMNS " + stats.keys.mkString(", ")) // Validate statistics - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val table = hiveClient.getTable("default", tableName) val props = table.properties.filterKeys(_.startsWith("spark.sql.statistics.colStats")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 7aff49c0fc3b1..f109843f5be20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -576,7 +576,7 @@ class VersionsSuite extends SparkFunSuite with Logging { versionSpark.sql("CREATE TABLE tbl AS SELECT 1 AS a") assert(versionSpark.table("tbl").collect().toSeq == Seq(Row(1))) val tableMeta = versionSpark.sessionState.catalog.getTableMetadata(TableIdentifier("tbl")) - val totalSize = tableMeta.properties.get(StatsSetupConst.TOTAL_SIZE).map(_.toLong) + val totalSize = tableMeta.stats.map(_.sizeInBytes) // Except 0.12, all the following versions will fill the Hive-generated statistics if (version == "0.12") { assert(totalSize.isEmpty) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index abe5d835719b6..98aa92a9bb88f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -192,12 +192,7 @@ abstract class HiveComparisonTest "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", // The following are hive specific schema parameters which we do not need to match exactly. 
- "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", "minFileSize" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 13f5c5dd8e80d..9a682260e2bf7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -1199,11 +1199,6 @@ class HiveDDLSuite "last_modified_by", "last_modified_time", "Owner:", - "COLUMN_STATS_ACCURATE", - "numFiles", - "numRows", - "rawDataSize", - "totalSize", "totalNumberFiles", "maxFileSize", "minFileSize" diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 5afb37b382e65..97e4c2b6b2db8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.sql.types.StructType * A test suite for Hive view related functionality. 
*/ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { - protected override val spark: SparkSession = TestHive.sparkSession - import testImplicits._ test("create a permanent/temp view using a hive, built-in, and permanent user function") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 6bfb88c0c1af5..52fa401d32c18 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -153,7 +153,6 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { - val hiveClient = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val location = Utils.createTempDir() val uri = location.toURI try {