diff --git a/.gitignore b/.gitignore index 07e94cddbec5a..8989e66bf3dea 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/NOTICE b/NOTICE index 69b513ea3ba3c..f4b64b5c3f470 100644 --- a/NOTICE +++ b/NOTICE @@ -421,9 +421,6 @@ Copyright (c) 2011, Terrence Parr. This product includes/uses ASM (http://asm.ow2.org/), Copyright (c) 2000-2007 INRIA, France Telecom. -This product includes/uses org.json (http://www.json.org/java/index.html), -Copyright (c) 2002 JSON.org - This product includes/uses JLine (http://jline.sourceforge.net/), Copyright (c) 2002-2006, Marc Prud'hommeaux . diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 0000000000000..bea8f9fbe4eec --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we use the `devtools` package. Please work with the +`dev@spark.apache.org` community and the R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `check-cran.sh` is running `R CMD check`, it is doing so with `--no-manual --no-vignettes`, which skips the vignette and PDF manual checks - it is therefore preferable to run `R CMD check` on the manually built source package before uploading a release. + +To upload a release, we need to update `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script, along with comments on the status of any `WARNING` (there should not be any) or `NOTE` entries. As part of `check-cran.sh` and the release process, the vignettes are built - make sure `SPARK_HOME` is set and the Spark jars are accessible. + +Once everything is in place, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths) +``` + +For more information, please refer to http://r-pkgs.had.co.nz/release.html#release-check + +### Testing: build package manually + +To build the package manually, for example to inspect the resulting `.tar.gz` file content, we also use the `devtools` package. + +The source package is what gets released to CRAN. CRAN then builds platform-specific binary packages from the source package. + +#### Build source package + +To build the source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths) +``` + +(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2) + +Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`. + +For example, this should be the content of the source package: + +```sh +DESCRIPTION R inst tests +NAMESPACE build man vignettes + +inst/doc/ +sparkr-vignettes.html +sparkr-vignettes.Rmd +sparkr-vignettes.Rman + +build/ +vignette.rds + +man/ + *.Rd files... + +vignettes/ +sparkr-vignettes.Rmd +``` + +#### Test source package + +To install, run this: + +```sh +R CMD INSTALL SparkR_2.1.0.tar.gz +``` + +Replace "2.1.0" with the version of SparkR being released (a sketch for looking up the version automatically is shown below).
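For convenience, the version can be read from `pkg/DESCRIPTION` instead of being typed by hand - a minimal sketch, mirroring how `check-cran.sh` extracts it, run from the `SPARK_HOME/R` directory:

```sh
# Read the package version from pkg/DESCRIPTION (same approach as check-cran.sh),
# then install the matching source tarball built in this directory.
VERSION=`grep Version pkg/DESCRIPTION | awk '{print $NF}'`
R CMD INSTALL SparkR_"$VERSION".tar.gz
```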
+ +This command installs SparkR to the default libPaths. Once that is done, you should be able to start R and run: + +```R +library(SparkR) +vignette("sparkr-vignettes", package="SparkR") +``` + +#### Build binary package + +To build binary package locally, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths) +``` + +For example, this should be the content of the binary package: + +```sh +DESCRIPTION Meta R html tests +INDEX NAMESPACE help profile worker +``` diff --git a/R/README.md b/R/README.md index 932d5272d0b4f..47f9a86dfde11 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. -Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,7 +46,7 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR @@ -54,11 +54,11 @@ sc <- sparkR.init(master="local") The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae931..c5f042848c90c 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -36,11 +36,27 @@ if [ ! -z "$R_HOME" ] fi echo "USING R_HOME = $R_HOME" -# Build the latest docs +# Build the latest docs, but not vignettes, which is built with the package next $FWDIR/create-docs.sh -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. 
"${SPARK_HOME}"/bin/load-spark-env.sh +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in $SPARK_HOME" + exit 1 +fi # Run check as-cran. VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` @@ -54,11 +70,16 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz - +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c36..84e6aa928cb0f 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. -# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -52,21 +52,4 @@ Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knit popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0e..fe41a9e7dabbd 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Version: 2.1.0 +Date: 2016-11-06 Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -18,7 +18,9 @@ Depends: Suggests: testthat, e1071, - survival + survival, + knitr, + rmarkdown Description: The SparkR package provides an R frontend for Apache Spark. 
License: Apache License (== 2.0) Collate: @@ -48,3 +50,4 @@ Collate: 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9cd6269f9a8f7..daee09de88263 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -45,7 +45,8 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", - "spark.randomForest") + "spark.randomForest", + "spark.gbt") # Job group lifecycle management methods export("setJobGroup", @@ -353,7 +354,9 @@ export("as.DataFrame", "read.ml", "print.summary.KSTest", "print.summary.RandomForestRegressionModel", - "print.summary.RandomForestClassificationModel") + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -380,6 +383,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1df8bbf9fe604..4e3d97bb3ad07 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -788,7 +788,7 @@ setMethod("write.json", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "json", path)) + invisible(handledCallJMethod(write, "json", path)) }) #' Save the contents of SparkDataFrame as an ORC file, preserving the schema. @@ -819,7 +819,7 @@ setMethod("write.orc", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "orc", path)) + invisible(handledCallJMethod(write, "orc", path)) }) #' Save the contents of SparkDataFrame as a Parquet file, preserving the schema. @@ -851,7 +851,7 @@ setMethod("write.parquet", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "parquet", path)) + invisible(handledCallJMethod(write, "parquet", path)) }) #' @rdname write.parquet @@ -895,7 +895,7 @@ setMethod("write.text", function(x, path, mode = "error", ...) { write <- callJMethod(x@sdf, "write") write <- setWriteOptions(write, mode = mode, ...) - invisible(callJMethod(write, "text", path)) + invisible(handledCallJMethod(write, "text", path)) }) #' Distinct #' @@ -936,7 +936,9 @@ setMethod("unique", #' Sample #' -#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of the given SparkDataFrame. #' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not @@ -3342,7 +3344,7 @@ setMethod("write.jdbc", jprops <- varargsToJProperties(...)
write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + invisible(handledCallJMethod(write, "jdbc", url, tableName, jprops)) }) #' randomSplit diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 216ca51666ba8..38d83c6e5c52b 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -350,7 +350,7 @@ read.json.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "json", paths) + sdf <- handledCallJMethod(read, "json", paths) dataFrame(sdf) } @@ -422,7 +422,7 @@ read.orc <- function(path, ...) { path <- suppressWarnings(normalizePath(path)) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "orc", path) + sdf <- handledCallJMethod(read, "orc", path) dataFrame(sdf) } @@ -444,7 +444,7 @@ read.parquet.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "parquet", paths) + sdf <- handledCallJMethod(read, "parquet", paths) dataFrame(sdf) } @@ -496,7 +496,7 @@ read.text.default <- function(path, ...) { paths <- as.list(suppressWarnings(normalizePath(path))) read <- callJMethod(sparkSession, "read") read <- callJMethod(read, "options", options) - sdf <- callJMethod(read, "text", paths) + sdf <- handledCallJMethod(read, "text", paths) dataFrame(sdf) } @@ -914,12 +914,13 @@ read.jdbc <- function(url, tableName, } else { numPartitions <- numToInt(numPartitions) } - sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), - numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) } else if (length(predicates) > 0) { - sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), + jprops) } else { - sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + sdf <- handledCallJMethod(read, "jdbc", url, tableName, jprops) } dataFrame(sdf) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4d94b4cd05d44..f8a9d3ce5d918 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1485,7 +1485,7 @@ setMethod("soundex", #' Return the partition ID as a column #' -#' Return the partition ID of the Spark task as a SparkDataFrame column. +#' Return the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. #' @@ -2317,7 +2317,8 @@ setMethod("date_format", signature(y = "Column", x = "character"), #' from_utc_timestamp #' -#' Assumes given timestamp is UTC and converts to given timezone. +#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp +#' that corresponds to the same time of day in the given timezone. #' #' @param y Column to compute on. #' @param x time zone to use. @@ -2340,7 +2341,7 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. 
#' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param y column to check @@ -2391,7 +2392,8 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' to_utc_timestamp #' -#' Assumes given timestamp is in given timezone and converts to UTC. +#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns +#' another timestamp that corresponds to the same time of day in UTC. #' #' @param y Column to compute on #' @param x timezone to use @@ -2539,7 +2541,7 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the given value numBits right. If the given value is a long value, it will return +#' (Signed) shift the given value numBits right. If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @param y column to compute on. @@ -2777,7 +2779,7 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' NOTE: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. @@ -2823,7 +2825,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' rand #' -#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -2852,7 +2855,8 @@ setMethod("rand", signature(seed = "numeric"), #' randn #' -#' Generate a column with i.i.d. samples from the standard normal distribution. +#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' the standard normal distribution. #' #' @param seed a random seed. Can be missing. #' @family normal_funcs @@ -3442,8 +3446,8 @@ setMethod("size", #' sort_array #' -#' Sorts the input array for the given column in ascending order, -#' according to the natural ordering of the array elements. +#' Sorts the input array in ascending or descending order according +#' to the natural ordering of the array elements. #' #' @param x A Column to sort #' @param asc A logical flag indicating the sorting order. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0271b26a10a90..499c7b279ea9d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1343,6 +1343,10 @@ setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + #' @rdname spark.glm #' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) @@ -1369,7 +1373,7 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname spark.mlp #' @export -setGeneric("spark.mlp", function(data, ...) { standardGeneric("spark.mlp") }) +setGeneric("spark.mlp", function(data, formula, ...) 
{ standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes #' @export diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7a220b8d53a2f..265e64e7466fa 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -116,6 +116,20 @@ setClass("RandomForestRegressionModel", representation(jobj = "jobj")) #' @note RandomForestClassificationModel since 2.1.0 setClass("RandomForestClassificationModel", representation(jobj = "jobj")) +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + #' Saves the MLlib model to the input path #' #' Saves the MLlib model to the input path. For more information, see the specific @@ -124,7 +138,8 @@ setClass("RandomForestClassificationModel", representation(jobj = "jobj")) #' @name write.ml #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg}, #' @seealso \link{read.ml} @@ -138,7 +153,8 @@ NULL #' @name predict #' @export #' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, +#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.gbt}, \link{spark.isoreg}, +#' @seealso \link{spark.kmeans}, #' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, #' @seealso \link{spark.randomForest}, \link{spark.survreg} NULL @@ -509,7 +525,7 @@ setMethod("write.ml", signature(object = "LDAModel", path = "character"), #' @note spark.isoreg since 2.1.0 setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { - formula <- paste0(deparse(formula), collapse = "") + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -634,7 +650,7 @@ setMethod("fitted", signature(object = "KMeansModel"), # Get the summary of a k-means model #' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. +#' @return \code{summary} returns the model's features, coefficients, k, size and cluster. #' @rdname spark.kmeans #' @export #' @note summary(KMeansModel) since 2.0.0 @@ -679,15 +695,15 @@ setMethod("predict", signature(object = "KMeansModel"), #' @param data SparkDataFrame for training #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. +#' @param regParam the regularization parameter. #' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. #' For alpha = 1.0, it is an L1 penalty. 
For 0.0 < alpha < 1.0, the penalty is a combination #' of L1 and L2. Default is 0.0 which is an L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. +#' @param fitIntercept whether to fit an intercept term. #' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". +#' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". @@ -705,11 +721,11 @@ setMethod("predict", signature(object = "KMeansModel"), #' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of #' predicting each class. Array must have length equal to the number of classes, with values > 0, #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. +#' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. #' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". +#' are large, this param could be adjusted to a larger size. +#' @param probabilityCol column name for predicted class conditional probabilities. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model #' @rdname spark.logit @@ -759,7 +775,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, probabilityCol = "probability") { - formula <- paste0(deparse(formula), collapse = "") + formula <- paste(deparse(formula), collapse = "") if (is.null(weightCol)) { weightCol <- "" @@ -791,8 +807,10 @@ setMethod("predict", signature(object = "LogisticRegressionModel"), # Get the summary of an LogisticRegressionModel #' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. +#' @return \code{summary} returns the Binary Logistic regression results of a given model as list, +#' including roc, areaUnderROC, pr, fMeasureByThreshold, precisionByThreshold, +#' recallByThreshold, totalIterations, objectiveHistory. Note that Multinomial logistic +#' regression summary is not available now. #' @rdname spark.logit #' @aliases summary,LogisticRegressionModel-method #' @export @@ -840,6 +858,8 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' Multilayer Perceptron} #' #' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param blockSize blockSize parameter. 
#' @param layers integer vector containing the number of nodes for each layer #' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". @@ -852,7 +872,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp -#' @aliases spark.mlp,SparkDataFrame-method +#' @aliases spark.mlp,SparkDataFrame,formula-method #' @name spark.mlp #' @seealso \link{read.ml} #' @export @@ -861,7 +881,7 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") #' #' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", +#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", #' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, #' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) #' @@ -878,9 +898,10 @@ setMethod("summary", signature(object = "LogisticRegressionModel"), #' summary(savedModel) #' } #' @note spark.mlp since 2.1.0 -setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, +setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { + formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { stop ("layers must be a integer vector with length > 1.") } @@ -895,7 +916,7 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame"), initialWeights <- as.array(as.numeric(na.omit(initialWeights))) } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", - "fit", data@sdf, as.integer(blockSize), as.array(layers), + "fit", data@sdf, formula, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), as.numeric(stepSize), seed, initialWeights) new("MultilayerPerceptronClassificationModel", jobj = jobj) @@ -918,9 +939,10 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel # Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} #' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{labelCount}, \code{layers}, and -#' \code{weights}. For \code{weights}, it is a numeric vector with length equal to -#' the expected given the architecture (i.e., for 8-10-2 network, 100 connection weights). +#' @return \code{summary} returns a list containing \code{numOfInputs}, \code{numOfOutputs}, +#' \code{layers}, and \code{weights}. For \code{weights}, it is a numeric vector with +#' length equal to the expected given the architecture (i.e., for 8-10-2 network, +#' 112 connection weights). 
#' @rdname spark.mlp #' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method @@ -928,10 +950,12 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), function(object) { jobj <- object@jobj - labelCount <- callJMethod(jobj, "labelCount") layers <- unlist(callJMethod(jobj, "layers")) + numOfInputs <- head(layers, n = 1) + numOfOutputs <- tail(layers, n = 1) weights <- callJMethod(jobj, "weights") - list(labelCount = labelCount, layers = layers, weights = weights) + list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs, + layers = layers, weights = weights) }) #' Naive Bayes Models @@ -1141,6 +1165,10 @@ read.ml <- function(path) { new("RandomForestRegressionModel", jobj = jobj) } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) } else { stop("Unsupported model: ", jobj) } @@ -1196,13 +1224,13 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula #' data and \code{write.ml}/\code{read.ml} to save/load fitted models. #' #' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". #' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 +#' each iteration of mini-batch gradient descent, in range (0, 1]. #' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for #' the prior placed on topic distributions over terms, default -1 to set automatically on the #' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size @@ -1263,7 +1291,7 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), # similarly to R's summary(). #' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, +#' @return \code{summary} returns a list containing the model's features, coefficients, #' intercept and log(scale) #' @rdname spark.survreg #' @export @@ -1351,7 +1379,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = # Get the summary of a multivariate gaussian mixture model #' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. +#' @return \code{summary} returns the model's lambda, mu, sigma, k, dim and posterior. #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method #' @rdname spark.gaussianMixture #' @export @@ -1644,33 +1672,38 @@ print.summary.KSTest <- function(x, ...) 
{ #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) +#' @param maxDepth Maximum depth of the tree (>= 0). #' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing #' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. (default = 32) +#' >= 2 and >= number of categories in any categorical feature. #' @param numTrees Number of trees to train (>= 1). #' @param impurity Criterion used for information gain calculation. #' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini". (default = gini) -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' "entropy" and "gini", default is "gini". #' @param featureSubsetStrategy The number of features to consider for splits at each tree node. #' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. #' @param seed integer seed for random number generation. #' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. (default = 1.0) -#' @param probabilityCol column name for predicted class conditional probabilities, only for -#' classification. (default = "probability") +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). #' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. #' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param probabilityCol column name for predicted class conditional probabilities, only for +#' classification. #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -1703,9 +1736,9 @@ print.summary.KSTest <- function(x, ...) 
{ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, type = c("regression", "classification"), maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, - minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, - probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE, probabilityCol = "probability") { type <- match.arg(type) formula <- paste(deparse(formula), collapse = "") if (!is.null(seed)) { @@ -1749,7 +1782,7 @@ setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "fo #' @rdname spark.randomForest #' @aliases predict,RandomForestRegressionModel-method #' @export -#' @note predict(randomForestRegressionModel) since 2.1.0 +#' @note predict(RandomForestRegressionModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestRegressionModel"), function(object, newData) { predict_internal(object, newData) @@ -1758,7 +1791,7 @@ setMethod("predict", signature(object = "RandomForestRegressionModel"), #' @rdname spark.randomForest #' @aliases predict,RandomForestClassificationModel-method #' @export -#' @note predict(randomForestClassificationModel) since 2.1.0 +#' @note predict(RandomForestClassificationModel) since 2.1.0 setMethod("predict", signature(object = "RandomForestClassificationModel"), function(object, newData) { predict_internal(object, newData) @@ -1789,8 +1822,8 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path write_internal(object, path, overwrite) }) -# Get the summary of an RandomForestRegressionModel model -summary.randomForest <- function(model) { +# Create the summary of a tree ensemble model (eg. Random Forest, GBT) +summary.treeEnsemble <- function(model) { jobj <- model@jobj formula <- callJMethod(jobj, "formula") numFeatures <- callJMethod(jobj, "numFeatures") @@ -1807,20 +1840,23 @@ summary.randomForest <- function(model) { jobj = jobj) } -#' @return \code{summary} returns the model's features as lists, depth and number of nodes -#' or number of classes. 
+# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights #' @rdname spark.randomForest #' @aliases summary,RandomForestRegressionModel-method #' @export #' @note summary(RandomForestRegressionModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestRegressionModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestRegressionModel" ans }) -# Get the summary of an RandomForestClassificationModel model +# Get the summary of a Random Forest Classification Model #' @rdname spark.randomForest #' @aliases summary,RandomForestClassificationModel-method @@ -1828,13 +1864,13 @@ setMethod("summary", signature(object = "RandomForestRegressionModel"), #' @note summary(RandomForestClassificationModel) since 2.1.0 setMethod("summary", signature(object = "RandomForestClassificationModel"), function(object) { - ans <- summary.randomForest(object) + ans <- summary.treeEnsemble(object) class(ans) <- "summary.RandomForestClassificationModel" ans }) -# Prints the summary of Random Forest Regression Model -print.summary.randomForest <- function(x) { +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { jobj <- x$jobj cat("Formula: ", x$formula) cat("\nNumber of features: ", x$numFeatures) @@ -1848,13 +1884,15 @@ print.summary.randomForest <- function(x) { invisible(x) } +# Prints the summary of Random Forest Regression Model + #' @param x summary object of Random Forest regression model or classification model #' returned by \code{summary}. #' @rdname spark.randomForest #' @export #' @note print.summary.RandomForestRegressionModel since 2.1.0 print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) } # Prints the summary of Random Forest Classification Model @@ -1863,5 +1901,214 @@ print.summary.RandomForestRegressionModel <- function(x, ...) { #' @export #' @note print.summary.RandomForestClassificationModel since 2.1.0 print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.randomForest(x) + print.summary.treeEnsemble(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). 
+#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). +#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' df <- createDataFrame(iris[iris$Species != "virginica", ]) +#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' +#' # numeric label is also supported +#' iris2 <- iris[iris$Species != "virginica", ] +#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) +#' df <- createDataFrame(iris2) +#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction" +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model +#' @param path The directory where the model is saved +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
+#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns a summary object of the fitted model, a list of components +#' including formula, number of features, list of features, feature importances, number of +#' trees, and tree weights +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) } diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index c4e78cbb804d9..20004549cc037 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -338,21 +338,41 @@ varargsToEnv <- function(...) { # into string. varargsToStrEnv <- function(...) { pairs <- list(...) + nameList <- names(pairs) env <- new.env() - for (name in names(pairs)) { - value <- pairs[[name]] - if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { - stop(paste0("Unsupported type for ", name, " : ", class(value), - ". Supported types are logical, numeric, character and NULL.")) - } - if (is.logical(value)) { - env[[name]] <- tolower(as.character(value)) - } else if (is.null(value)) { - env[[name]] <- value - } else { - env[[name]] <- as.character(value) + ignoredNames <- list() + + if (is.null(nameList)) { + # When all arguments are not named, names(..) returns NULL. + ignoredNames <- pairs + } else { + for (i in seq_along(pairs)) { + name <- nameList[i] + value <- pairs[i] + if (identical(name, "")) { + # When some of arguments are not named, name is "". 
+ ignoredNames <- append(ignoredNames, value) + } else { + value <- pairs[[name]] + if (!(is.logical(value) || is.numeric(value) || is.character(value) || is.null(value))) { + stop(paste0("Unsupported type for ", name, " : ", class(value), + ". Supported types are logical, numeric, character and NULL."), call. = FALSE) + } + if (is.logical(value)) { + env[[name]] <- tolower(as.character(value)) + } else if (is.null(value)) { + env[[name]] <- value + } else { + env[[name]] <- as.character(value) + } + } } } + + if (length(ignoredNames) != 0) { + warning(paste0("Unnamed arguments ignored: ", paste(ignoredNames, collapse = ", "), "."), + call. = FALSE) + } env } diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index db98d0e45547e..2a97a51cfa205 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -64,6 +64,16 @@ test_that("spark.glm and predict", { rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # binomial family + binomialTraining <- training[training$Species %in% c("versicolor", "virginica"), ] + model <- spark.glm(binomialTraining, Species ~ Sepal_Length + Sepal_Width, + family = binomial(link = "logit")) + prediction <- predict(model, binomialTraining) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("virginica", "virginica", "virginica", "versicolor", "virginica", + "versicolor", "virginica", "versicolor", "virginica", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) + # poisson family model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, family = poisson(link = identity)) @@ -128,12 +138,12 @@ test_that("spark.glm summary", { expect_equal(stats$aic, rStats$aic) # Test spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3) - a2 <- c(5, 2, 1, 3) - w <- c(1, 2, 3, 4) - b <- c(1, 0, 1, 0) + a1 <- c(0, 1, 2, 3, 4) + a2 <- c(5, 2, 1, 3, 2) + w <- c(1, 2, 3, 4, 5) + b <- c(1, 0, 1, 0, 0) data <- as.data.frame(cbind(a1, a2, w, b)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) @@ -158,7 +168,7 @@ test_that("spark.glm summary", { data <- as.data.frame(cbind(a1, a2, b)) df <- suppressWarnings(createDataFrame(data)) regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + expect_equal(regStats$aic, 14.00976, tolerance = 1e-4) # 14.00976 is from summary() result }) test_that("spark.glm save/load", { @@ -361,12 +371,13 @@ test_that("spark.kmeans", { test_that("spark.mlp", { df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), source = "libsvm") - model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) + model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), + solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) # Test summary method summary <- summary(model) - expect_equal(summary$labelCount, 3) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 5, 4, 3)) expect_equal(length(summary$weights), 
64) expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), @@ -375,7 +386,7 @@ test_that("spark.mlp", { # Test predict method mlpTestDF <- df mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1)) + expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) # Test model save/load modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") @@ -385,46 +396,68 @@ test_that("spark.mlp", { model2 <- read.ml(modelPath) summary2 <- summary(model2) - expect_equal(summary2$labelCount, 3) + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) expect_equal(summary2$layers, c(4, 5, 4, 3)) expect_equal(length(summary2$weights), 64) unlink(modelPath) # Test default parameter - model <- spark.mlp(df, layers = c(4, 5, 4, 3)) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test illegal parameter - expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = NULL), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c()), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c(3)), + "layers must be a integer vector with length > 1.") # Test random seed # default seed - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # seed equals 10 - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter 
= 2, initialWeights = c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) + model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + + # Test formula works well + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + layers = c(4, 3)) + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 3)) + expect_equal(length(summary$weights), 15) + expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, + -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { @@ -575,7 +608,7 @@ test_that("spark.isotonicRegression", { feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) data <- as.data.frame(cbind(label, feature, weight)) - df <- suppressWarnings(createDataFrame(data)) + df <- createDataFrame(data) model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, weightCol = "weight") @@ -871,7 +904,8 @@ test_that("spark.kstest", { expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") }) -test_that("spark.randomForest Regression", { +test_that("spark.randomForest", { + # regression data <- suppressWarnings(createDataFrame(longley)) model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, numTrees = 1) @@ -913,9 +947,8 @@ test_that("spark.randomForest Regression", { expect_equal(stats$treeWeights, stats2$treeWeights) unlink(modelPath) -}) -test_that("spark.randomForest Classification", { + # classification data <- suppressWarnings(createDataFrame(iris)) model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", maxDepth = 5, maxBins = 16) @@ -925,6 +958,10 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numTrees, 20) expect_error(capture.output(stats), NA) expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") write.ml(model, modelPath) @@ -937,6 +974,106 @@ test_that("spark.randomForest Classification", { expect_equal(stats$numClasses, stats2$numClasses) unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + 
expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. 
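
The binary-label restriction noted in the comment above comes from the JVM side: SparkR's `spark.gbt(..., "classification")` delegates to ML's `GBTClassifier`, which currently only supports two-class labels. A minimal Java sketch of the equivalent direct call, assuming a local `SparkSession` and the `data/mllib/sample_binary_classification_data.txt` file that ships with Spark (path relative to a Spark checkout; class name is illustrative):

```java
import org.apache.spark.ml.classification.GBTClassificationModel;
import org.apache.spark.ml.classification.GBTClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class GbtBinaryLabelSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]").appName("gbt-binary-label-sketch").getOrCreate();

    // Binary-labelled libsvm data: "label" is 0.0/1.0, "features" is a vector column.
    Dataset<Row> data = spark.read().format("libsvm")
        .load("data/mllib/sample_binary_classification_data.txt");

    // GBTClassifier is the backend behind spark.gbt(..., "classification"); a label column
    // with more than two distinct values would fail at fit() time.
    GBTClassificationModel model = new GBTClassifier().setMaxIter(5).fit(data);
    System.out.println("trees: " + model.trees().length);

    spark.stop();
  }
}
```
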
+ iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) }) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 9289db57b6d63..ee48baa59c7af 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1222,16 +1222,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() @@ -2659,7 +2659,15 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume # It makes sure that we can omit path argument in write.df API and then it calls # DataFrameWriter.save() without path. 
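
For context on the new error expectations below: when no path is given, the R `write.df` call reaches `DataFrameWriter.save()` on the JVM side, and file-based sources reject the missing path there. A rough Java sketch of the same failure mode, assuming a local `SparkSession` (class name is illustrative):

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SaveWithoutPathSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]").appName("save-without-path-sketch").getOrCreate();

    Dataset<Row> df = spark.range(3).toDF("id");

    try {
      // No path and no "path" option: file-based sources such as csv fail here, which is
      // the JVM-side error that the SparkR expect_error() checks below are matching against.
      df.write().format("csv").save();
    } catch (IllegalArgumentException e) {
      System.out.println("save() without a path failed as expected: " + e.getMessage());
    }

    spark.stop();
  }
}
```
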
expect_error(write.df(df, source = "csv"), - "Error in save : illegal argument - 'path' is not specified") + "Error in save : illegal argument - Expected exactly one path to be specified") + expect_error(write.json(df, jsonPath), + "Error in json : analysis error - path file:.*already exists") + expect_error(write.text(df, jsonPath), + "Error in text : analysis error - path file:.*already exists") + expect_error(write.orc(df, jsonPath), + "Error in orc : analysis error - path file:.*already exists") + expect_error(write.parquet(df, jsonPath), + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), @@ -2679,6 +2687,11 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") + expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") + expect_error(read.text("arbitrary_path"), "Error in text : analysis error - Path does not exist") + expect_error(read.orc("arbitrary_path"), "Error in orc : analysis error - Path does not exist") + expect_error(read.parquet("arbitrary_path"), + "Error in parquet : analysis error - Path does not exist") # Arguments checking in R side. expect_error(read.df(path = c(3)), @@ -2686,6 +2699,9 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) + + expect_warning(read.json(jsonPath, a = 1, 2, 3, "a"), + "Unnamed arguments ignored: 2, 3, a.") }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index a20254e9b3fa9..607c407f04f97 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -224,6 +224,8 @@ test_that("varargsToStrEnv", { expect_error(varargsToStrEnv(a = list(1, "a")), paste0("Unsupported type for a : list. 
Supported types are logical, ", "numeric, character and NULL.")) + expect_warning(varargsToStrEnv(a = 1, 2, 3, 4), "Unnamed arguments ignored: 2, 3, 4.") + expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bddb..73a5e26a3ba9c 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,12 +1,13 @@ --- title: "SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- ## Overview diff --git a/bin/beeline b/bin/beeline index 1627626941a73..058534699e44b 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 0000000000000..fa78407d4175a --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. + # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. 
+ export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3d..8a2f709960a25 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a44321..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c4120260..4ba5399311d33 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f6..77ea40cc37946 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. 
-if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880ee..421f36cac3d47 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51dd..b08b944ebd319 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b8..4e9d3614e6370 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173b..29ab10df8ab6d 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c91..ca99fa89ebe1b 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -76,6 +76,10 @@ guava compile + + org.apache.commons + commons-crypto + diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9e5c616ee5a1f..a1bb453657460 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -30,6 +30,8 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; @@ -88,9 +90,26 @@ public void doBootstrap(TransportClient client, Channel channel) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } - SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + + if (conf.aesEncryptionEnabled()) { + // Generate a request config message to send to server. + AesConfigMessage configMessage = AesCipher.createConfigMessage(conf); + ByteBuffer buf = configMessage.encodeMessage(); + + // Encrypted the config message. 
+ byte[] toEncrypt = JavaUtils.bufferToArray(buf); + ByteBuffer encrypted = ByteBuffer.wrap(saslClient.wrap(toEncrypt, 0, toEncrypt.length)); + + client.sendRpcSync(encrypted, conf.saslRTTimeoutMs()); + AesCipher cipher = new AesCipher(configMessage, conf); + logger.info("Enabling AES cipher for client channel {}", client); + cipher.addToChannel(channel); + saslClient.dispose(); + } else { + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); + } saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); + logger.debug("Channel {} configured for encryption.", client); } } catch (IOException ioe) { throw new RuntimeException(ioe); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6c..b2f3ef214b7ac 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -29,6 +29,8 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.aes.AesCipher; +import org.apache.spark.network.sasl.aes.AesConfigMessage; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.util.JavaUtils; @@ -59,6 +61,7 @@ class SaslRpcHandler extends RpcHandler { private SparkSaslServer saslServer; private boolean isComplete; + private boolean isAuthenticated; SaslRpcHandler( TransportConf conf, @@ -71,6 +74,7 @@ class SaslRpcHandler extends RpcHandler { this.secretKeyHolder = secretKeyHolder; this.saslServer = null; this.isComplete = false; + this.isAuthenticated = false; } @Override @@ -80,30 +84,31 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb delegate.receive(client, message, callback); return; } + if (saslServer == null || !saslServer.isComplete()) { + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); } - callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. 
It's ok to change the channel pipeline here since we are processing an incoming @@ -111,15 +116,42 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("SASL authentication successful for channel {}", client); + complete(true); + return; + } + + if (!conf.aesEncryptionEnabled()) { logger.debug("Enabling encryption for channel {}", client); SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; + complete(false); + return; + } + + // Extra negotiation should happen after authentication, so return directly while + // processing authenticate. + if (!isAuthenticated) { + logger.debug("SASL authentication successful for channel {}", client); + isAuthenticated = true; + return; + } + + // Create AES cipher when it is authenticated + try { + byte[] encrypted = JavaUtils.bufferToArray(message); + ByteBuffer decrypted = ByteBuffer.wrap(saslServer.unwrap(encrypted, 0 , encrypted.length)); + + AesConfigMessage configMessage = AesConfigMessage.decodeMessage(decrypted); + AesCipher cipher = new AesCipher(configMessage, conf); + + // Send response back to client to confirm that server accept config. + callback.onSuccess(JavaUtils.stringToBytes(AesCipher.TRANSFORM)); + logger.info("Enabling AES cipher for Server channel {}", client); + cipher.addToChannel(channel); + complete(true); + } catch (IOException ioe) { + throw new RuntimeException(ioe); } } } @@ -155,4 +187,17 @@ public void exceptionCaught(Throwable cause, TransportClient client) { delegate.exceptionCaught(cause, client); } + private void complete(boolean dispose) { + if (dispose) { + try { + saslServer.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL server", e); + } + } + + saslServer = null; + isComplete = true; + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java new file mode 100644 index 0000000000000..78034a69f734d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesCipher.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.AbstractReferenceCounted; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * AES cipher for encryption and decryption. + */ +public class AesCipher { + private static final Logger logger = LoggerFactory.getLogger(AesCipher.class); + public static final String ENCRYPTION_HANDLER_NAME = "AesEncryption"; + public static final String DECRYPTION_HANDLER_NAME = "AesDecryption"; + public static final int STREAM_BUFFER_SIZE = 1024 * 32; + public static final String TRANSFORM = "AES/CTR/NoPadding"; + + private final SecretKeySpec inKeySpec; + private final IvParameterSpec inIvSpec; + private final SecretKeySpec outKeySpec; + private final IvParameterSpec outIvSpec; + private final Properties properties; + + public AesCipher(AesConfigMessage configMessage, TransportConf conf) throws IOException { + this.properties = CryptoStreamUtils.toCryptoConf(conf); + this.inKeySpec = new SecretKeySpec(configMessage.inKey, "AES"); + this.inIvSpec = new IvParameterSpec(configMessage.inIv); + this.outKeySpec = new SecretKeySpec(configMessage.outKey, "AES"); + this.outIvSpec = new IvParameterSpec(configMessage.outIv); + } + + /** + * Create AES crypto output stream + * @param ch The underlying channel to write out. + * @return Return output crypto stream for encryption. + * @throws IOException + */ + private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(TRANSFORM, properties, ch, outKeySpec, outIvSpec); + } + + /** + * Create AES crypto input stream + * @param ch The underlying channel used to read data. + * @return Return input crypto stream for decryption. + * @throws IOException + */ + private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(TRANSFORM, properties, ch, inKeySpec, inIvSpec); + } + + /** + * Add handlers to channel + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new AesEncryptHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new AesDecryptHandler(this)); + } + + /** + * Create the configuration message + * @param conf is the local transport configuration. + * @return Config message for sending. 
+ */ + public static AesConfigMessage createConfigMessage(TransportConf conf) { + int keySize = conf.aesCipherKeySize(); + Properties properties = CryptoStreamUtils.toCryptoConf(conf); + + try { + int paramLen = CryptoCipherFactory.getCryptoCipher(AesCipher.TRANSFORM, properties) + .getBlockSize(); + byte[] inKey = new byte[keySize]; + byte[] outKey = new byte[keySize]; + byte[] inIv = new byte[paramLen]; + byte[] outIv = new byte[paramLen]; + + CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties); + random.nextBytes(inKey); + random.nextBytes(outKey); + random.nextBytes(inIv); + random.nextBytes(outIv); + + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } catch (Exception e) { + logger.error("AES config error", e); + throw Throwables.propagate(e); + } + } + + /** + * CryptoStreamUtils is used to convert config from TransportConf to AES Crypto config. + */ + private static class CryptoStreamUtils { + public static Properties toCryptoConf(TransportConf conf) { + Properties props = new Properties(); + if (conf.aesCipherClass() != null) { + props.setProperty(CryptoCipherFactory.CLASSES_KEY, conf.aesCipherClass()); + } + return props; + } + } + + private static class AesEncryptHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteChannel; + private final CryptoOutputStream cos; + + AesEncryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteChannel); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + cos.close(); + } finally { + super.close(ctx, promise); + } + } + } + + private static class AesDecryptHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + + AesDecryptHandler(AesCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + byteChannel.feedData((ByteBuf) data); + + byte[] decryptedData = new byte[byteChannel.readableBytes()]; + int offset = 0; + while (offset < decryptedData.length) { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + cis.close(); + } finally { + super.channelInactive(ctx); + } + } + } + + private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private long transferred; + private CryptoOutputStream cos; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. 
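
To make the key exchange above concrete, here is a small sketch that round-trips an `AesConfigMessage` through `encodeMessage()`/`decodeMessage()` (the class itself follows later in this diff). The fixed key/IV bytes are purely illustrative; the real bootstrap obtains them from `AesCipher.createConfigMessage(conf)`. Note how decoding deliberately swaps directions, so the sender's outbound key becomes the receiver's inbound key:

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

import org.apache.spark.network.sasl.aes.AesConfigMessage;

public class AesConfigMessageRoundTrip {
  public static void main(String[] args) {
    // Illustrative 16-byte keys/IVs (the default aesCipherKeySize); real ones come from CryptoRandom.
    byte[] inKey = new byte[16], inIv = new byte[16], outKey = new byte[16], outIv = new byte[16];
    Arrays.fill(inKey, (byte) 1);
    Arrays.fill(inIv, (byte) 2);
    Arrays.fill(outKey, (byte) 3);
    Arrays.fill(outIv, (byte) 4);

    AesConfigMessage sent = new AesConfigMessage(inKey, inIv, outKey, outIv);
    ByteBuffer wire = sent.encodeMessage();

    AesConfigMessage received = AesConfigMessage.decodeMessage(wire);
    System.out.println(Arrays.equals(received.inKey, outKey));   // true: sender's out == receiver's in
    System.out.println(Arrays.equals(received.outKey, inKey));   // true: sender's in  == receiver's out
  }
}
```
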
+ private ByteArrayWritableChannel byteEncChannel; + private ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.byteRawChannel = new ByteArrayWritableChannel(AesCipher.STREAM_BUFFER_SIZE); + this.cos = cos; + this.byteEncChannel = ch; + } + + @Override + public long count() { + return isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return transferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + do { + if (currentEncrypted == null) { + encryptMore(); + } + + int bytesWritten = currentEncrypted.remaining(); + target.write(currentEncrypted); + bytesWritten -= currentEncrypted.remaining(); + transferred += bytesWritten; + if (!currentEncrypted.hasRemaining()) { + currentEncrypted = null; + byteEncChannel.reset(); + } + } while (transferred < count()); + + return transferred; + } + + private void encryptMore() throws IOException { + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transfered()); + } + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java new file mode 100644 index 0000000000000..3ef6f74a1f89f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/aes/AesConfigMessage.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl.aes; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The AES cipher options for encryption negotiation. + */ +public class AesConfigMessage implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEB; + + public byte[] inKey; + public byte[] outKey; + public byte[] inIv; + public byte[] outIv; + + public AesConfigMessage(byte[] inKey, byte[] inIv, byte[] outKey, byte[] outIv) { + if (inKey == null || inIv == null || outKey == null || outIv == null) { + throw new IllegalArgumentException("Cipher Key or IV must not be null!"); + } + + this.inKey = inKey; + this.inIv = inIv; + this.outKey = outKey; + this.outIv = outIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(inKey) + Encoders.ByteArrays.encodedLength(outKey) + + Encoders.ByteArrays.encodedLength(inIv) + Encoders.ByteArrays.encodedLength(outIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, inKey); + Encoders.ByteArrays.encode(buf, inIv); + Encoders.ByteArrays.encode(buf, outKey); + Encoders.ByteArrays.encode(buf, outIv); + } + + /** + * Encode the config message. + * @return ByteBuffer which contains encoded config message. + */ + public ByteBuffer encodeMessage(){ + ByteBuffer buf = ByteBuffer.allocate(encodedLength()); + + ByteBuf wrappedBuf = Unpooled.wrappedBuffer(buf); + wrappedBuf.clear(); + encode(wrappedBuf); + + return buf; + } + + /** + * Decode the config message from buffer + * @param buffer the buffer contain encoded config message + * @return config message + */ + public static AesConfigMessage decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected AesConfigMessage, received something else" + + " (maybe your client does not have AES enabled?)"); + } + + byte[] outKey = Encoders.ByteArrays.decode(buf); + byte[] outIv = Encoders.ByteArrays.decode(buf); + byte[] inKey = Encoders.ByteArrays.decode(buf); + byte[] inIv = Encoders.ByteArrays.decode(buf); + return new AesConfigMessage(inKey, inIv, outKey, outIv); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java new file mode 100644 index 0000000000000..25d103d0e316f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import io.netty.buffer.ByteBuf; + +public class ByteArrayReadableChannel implements ReadableByteChannel { + private ByteBuf data; + + public int readableBytes() { + return data.readableBytes(); + } + + public void feedData(ByteBuf buf) { + data = buf; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int totalRead = 0; + while (data.readableBytes() > 0 && dst.remaining() > 0) { + int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); + dst.put(data.readSlice(bytesToRead).nioBuffer()); + totalRead += bytesToRead; + } + + if (data.readableBytes() == 0) { + data.release(); + } + + return totalRead; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103cccb..012bb098f6fc4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -175,4 +175,25 @@ public boolean saslServerAlwaysEncrypt() { return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } + /** + * The trigger for enabling AES encryption. + */ + public boolean aesEncryptionEnabled() { + return conf.getBoolean("spark.authenticate.encryption.aes.enabled", false); + } + + /** + * The implementation class for crypto cipher + */ + public String aesCipherClass() { + return conf.get("spark.authenticate.encryption.aes.cipher.class", null); + } + + /** + * The bytes of AES cipher key which is effective when AES cipher is enabled. Notice that + * the length should be 16, 24 or 32 bytes. 
+ */ + public int aesCipherKeySize() { + return conf.getInt("spark.authenticate.encryption.aes.cipher.keySize", 16); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 45cc03df435ac..ef2ab34b2277c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -53,6 +53,7 @@ import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.aes.AesCipher; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; @@ -149,7 +150,7 @@ public Void answer(InvocationOnMock invocation) { .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); - SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false); + SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false, false); try { ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); @@ -275,7 +276,7 @@ public ManagedBuffer answer(InvocationOnMock invocation) { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false); + ctx = new SaslTestCtx(rpcHandler, true, false, false); final CountDownLatch lock = new CountDownLatch(1); @@ -317,7 +318,7 @@ public void testServerAlwaysEncrypt() throws Exception { SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, false); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -336,7 +337,7 @@ public void testDataEncryptionIsActuallyEnabled() throws Exception { // able to understand RPCs sent to it and thus close the connection. 
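
As a reference for the new `spark.authenticate.encryption.aes.*` options read by `TransportConf` above, here is a sketch of how an application might turn them on. AES cipher negotiation rides on an authenticated, SASL-encrypted channel, so the pre-existing authentication flags are shown as well; the secret value is a placeholder and the class name is illustrative:

```java
import org.apache.spark.SparkConf;

public class AesEncryptionConfSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf()
        .setMaster("local[2]")
        .setAppName("aes-encryption-conf-sketch")
        .set("spark.authenticate", "true")
        .set("spark.authenticate.secret", "placeholder-secret")        // illustrative only
        .set("spark.authenticate.enableSaslEncryption", "true")
        .set("spark.authenticate.encryption.aes.enabled", "true")
        // Optional: 16 (default), 24 or 32 bytes, matching aesCipherKeySize() above.
        .set("spark.authenticate.encryption.aes.cipher.keySize", "32");

    System.out.println(conf.get("spark.authenticate.encryption.aes.enabled"));  // true
  }
}
```
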
SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), true, true); + ctx = new SaslTestCtx(mock(RpcHandler.class), true, true, false); ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), TimeUnit.SECONDS.toMillis(10)); fail("Should have failed to send RPC to server."); @@ -374,6 +375,69 @@ public void testDelegates() throws Exception { } } + @Test + public void testAesEncryption() throws Exception { + final AtomicReference response = new AtomicReference<>(); + final File file = File.createTempFile("sasltest", ".txt"); + SaslTestCtx ctx = null; + try { + final TransportConf conf = new TransportConf("rpc", new SystemPropertyConfigProvider()); + final TransportConf spyConf = spy(conf); + doReturn(true).when(spyConf).aesEncryptionEnabled(); + + StreamManager sm = mock(StreamManager.class); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { + @Override + public ManagedBuffer answer(InvocationOnMock invocation) { + return new FileSegmentManagedBuffer(spyConf, file, 0, file.length()); + } + }); + + RpcHandler rpcHandler = mock(RpcHandler.class); + when(rpcHandler.getStreamManager()).thenReturn(sm); + + byte[] data = new byte[256 * 1024 * 1024]; + new Random().nextBytes(data); + Files.write(data, file); + + ctx = new SaslTestCtx(rpcHandler, true, false, true); + + final Object lock = new Object(); + + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) { + response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + synchronized (lock) { + lock.notifyAll(); + } + return null; + } + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + + synchronized (lock) { + ctx.client.fetchChunk(0, 0, callback); + lock.wait(10 * 1000); + } + + verify(callback, times(1)).onSuccess(anyInt(), any(ManagedBuffer.class)); + verify(callback, never()).onFailure(anyInt(), any(Throwable.class)); + + byte[] received = ByteStreams.toByteArray(response.get().createInputStream()); + assertTrue(Arrays.equals(data, received)); + } finally { + file.delete(); + if (ctx != null) { + ctx.close(); + } + if (response.get() != null) { + response.get().release(); + } + } + } + private static class SaslTestCtx { final TransportClient client; @@ -386,18 +450,28 @@ private static class SaslTestCtx { SaslTestCtx( RpcHandler rpcHandler, boolean encrypt, - boolean disableClientEncryption) + boolean disableClientEncryption, + boolean aesEnable) throws Exception { TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + if (aesEnable) { + conf = spy(conf); + doReturn(true).when(conf).aesEncryptionEnabled(); + } + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); when(keyHolder.getSecretKey(anyString())).thenReturn("secret"); TransportContext ctx = new TransportContext(conf, rpcHandler); - this.checker = new EncryptionCheckerBootstrap(); + String encryptHandlerName = aesEnable ? 
AesCipher.ENCRYPTION_HANDLER_NAME : + SaslEncryption.ENCRYPTION_HANDLER_NAME; + + this.checker = new EncryptionCheckerBootstrap(encryptHandlerName); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); @@ -437,13 +511,18 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd implements TransportServerBootstrap { boolean foundEncryptionHandler; + String encryptHandlerName; + + EncryptionCheckerBootstrap(String encryptHandlerName) { + this.encryptHandlerName = encryptHandlerName; + } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (!foundEncryptionHandler) { foundEncryptionHandler = - ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + ctx.channel().pipeline().get(encryptHandlerName) != null; } ctx.write(msg, promise); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a753..a7b0e6f80c2b6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -252,6 +252,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public final long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..ea5f1a9abf69b 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55b..3dd318471008b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
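
The index and offset arithmetic above is widened from `int` to `long` so that very large `LongArray` buffers cannot overflow intermediate values such as `numRecords * 8`; only the final in-array index is narrowed back, via Guava's checked cast. A small sketch of that behavior (class name is illustrative):

```java
import com.google.common.primitives.Ints;

public class CheckedCastSketch {
  public static void main(String[] args) {
    // Narrowing a long that fits in an int is a no-op...
    System.out.println(Ints.checkedCast(123_456L));        // 123456

    // ...while an out-of-range value fails loudly instead of silently wrapping around,
    // which is the point of returning Ints.checkedCast(inIndex) in the sorts above.
    try {
      Ints.checkedCast(Integer.MAX_VALUE + 1L);
    } catch (IllegalArgumentException e) {
      System.out.println("overflow detected: " + e.getMessage());
    }
  }
}
```
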
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. 
*/ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafad..252a35ec6bdf5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 0000000000000..55d540d8317a0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 1fd6ef4a71253..42e2d9abdeb5e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -68,16 +68,16 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} - {{attemptId}} + {{attemptId}} {{startTime}} {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2a32e18672a22..8fd91865b0429 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -119,7 +119,11 @@ $(document).ready(function() { } } - var data = {"applications": array} + var data = { + "uiroot": uiRoot, + "applications": array + } + $.get("static/historypage-template.html", function(template) { historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); var selector = "#history-summary-table"; @@ -135,6 +139,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860ed..0315ebf5c48a9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e37307aa1f705..0fa1fcf25f8b9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var uiRoot = ""; + +function setUIRoot(val) { + uiRoot = val; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e76..af913454fce69 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -139,7 +139,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed65..f83f5278e8b8f 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. */ diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c97..04d657c09afd0 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. 
*/ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd8..1261e3e735761 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") + warnDeprecatedVersions() + /* ------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -279,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -346,6 +348,16 @@ class SparkContext(config: SparkConf) extends Logging { value } + private def warnDeprecatedVersions(): Unit = { + val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) + if (javaVersion.length >= 2 && javaVersion(1).toInt == 7) { + logWarning("Support for Java 7 is deprecated as of Spark 2.0.0") + } + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { + logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") + } + } + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN @@ -688,7 +700,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -915,7 +927,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. * * @param path Directory to the input data files, the path can be comma separated paths as the @@ -958,7 +970,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. 
* If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -983,7 +995,7 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1022,7 +1034,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1046,7 +1058,7 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1072,7 +1084,7 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1112,7 +1124,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1138,7 +1150,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. 
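The note repeated across these `hadoopFile`/`sequenceFile` overloads (Hadoop's `RecordReader` hands back the same `Writable` instance for every record) is the kind of pitfall a small example makes concrete. A sketch of the recommended copy-before-cache pattern, written against the Scala API with a hypothetical path and assuming an existing `SparkContext` named `sc`:

```scala
import org.apache.hadoop.io.{IntWritable, Text}

// Copy plain values out of the reused Writable objects before caching or shuffling.
val safe = sc.sequenceFile("/data/example.seq", classOf[Text], classOf[IntWritable])
  .map { case (k, v) => (k.toString, v.get) }
  .cache()
```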
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1157,7 +1169,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1187,7 +1199,7 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1318,16 +1330,18 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { acc.register(this, name = Some(name)) @@ -1538,7 +1552,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1560,7 +1574,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. 
If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1578,7 +1592,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1627,7 +1641,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -1716,29 +1731,12 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null - } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null - } + try { + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case exc: FileNotFoundException => + logError(s"Jar not found at $path") + null } // A JAR file which exists locally on every worker node case "local" => @@ -1762,8 +1760,26 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq - // Shut down the SparkContext. - def stop() { + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { + if (env.rpcEnv.isInRPCThread) { + // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread. + // We should launch a new thread to call `stop` to avoid dead-lock. + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + _stop() + } + }.start() + } else { + _stop() + } + } + + private def _stop() { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") @@ -2285,7 +2301,7 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. 
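The reworded note on `getOrCreate` pairs naturally with the behaviour it documents: the first caller creates the context and every later caller gets that same instance back, regardless of the conf it passes. A minimal sketch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// The second conf is effectively ignored because a SparkContext already exists.
val sc1 = SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("first"))
val sc2 = SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("second"))
assert(sc1 eq sc2)
```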
*/ def getOrCreate(config: SparkConf): SparkContext = { @@ -2310,7 +2326,7 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. */ def getOrCreate(): SparkContext = { diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 6550d703bc860..46e22b215b8ee 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -20,14 +20,14 @@ package org.apache.spark import java.io.IOException import java.text.NumberFormat import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.SparkHadoopWriterUtils import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD import org.apache.spark.util.SerializableJobConf @@ -67,12 +67,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -153,29 +153,8 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { splitID = splitid attemptID = attemptid - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } - -private[spark] -object SparkHadoopWriter { - def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(time) - new JobID(jobtrackerID, id) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } -} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad517..a32a4b28c1731 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -153,7 +153,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. 
* - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +256,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaaf..bff5a29bb60f1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -206,7 +206,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +223,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +234,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +258,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +268,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. 
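The restructured `combineByKey` docs above pull the three user-supplied functions out of the prose; the (Int, Int) to (Int, List[Int]) case they mention is quick to sketch. This uses the Scala pair-RDD API rather than the Java wrapper shown in the diff, and assumes an existing `SparkContext` named `sc`:

```scala
val pairs = sc.parallelize(Seq((1, 10), (1, 20), (2, 30)))

// V = Int, C = List[Int]
val grouped = pairs.combineByKey[List[Int]](
  (v: Int) => List(v),                           // createCombiner: start a one-element list
  (c: List[Int], v: Int) => v :: c,              // mergeValue: fold a value into a partial list
  (c1: List[Int], c2: List[Int]) => c1 ::: c2)   // mergeCombiners: merge lists across partitions
grouped.collect()   // e.g. Array((1, List(20, 10)), (2, List(30)))
```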
+ * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,7 +404,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -409,7 +415,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ @@ -539,7 +545,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] * will provide much better performance. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7a..ccd94f876e0b8 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -98,24 +98,30 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. 
* * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -153,7 +159,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf210..eda16d957cc58 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd53..38d347aeab8c6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -298,7 +298,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +316,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -366,7 +366,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. 
* - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +396,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -416,7 +416,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +437,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +458,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +487,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +694,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. 
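The `@note` on `hadoopConfiguration` above (one shared `Configuration` reused by every Hadoop RDD) suggests treating it as write-once, global setup. A one-line sketch with a hypothetical setting, assuming an existing context `sc`; per-input options are better passed through the job-specific `Configuration` overloads than by mutating this shared instance:

```scala
// Hypothetical global setting: every Hadoop-backed RDD created afterwards will see it.
sc.hadoopConfiguration.set("fs.s3a.connection.maximum", "100")
```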
*/ def hadoopConfiguration(): Configuration = { @@ -811,7 +811,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced0..6aa290ecd7bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. */ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 77825e75e5136..fdd8cf62f0e5f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -84,7 +84,6 @@ private[spark] object RUtils { } } else { // Otherwise, assume the package is local - // TODO: support this for Mesos val sparkRPkgPath = localSparkRPackagePath.getOrElse { throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.") } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b71138..a4de3d7eaf458 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac33..23156072c3ebe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,7 +21,7 @@ import java.io.IOException import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -357,7 +357,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5c052286099f5..85f80b6971e80 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -41,12 +41,11 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException} -import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL} +import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Whether to submit, kill, or request the status of an application. @@ -63,7 +62,7 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. */ -object SparkSubmit { +object SparkSubmit extends CommandLineUtils { // Cluster managers private val YARN = 1 @@ -87,15 +86,6 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // scalastyle:off println - // Exposed for testing - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) - private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn(1) - } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to ____ __ @@ -115,7 +105,7 @@ object SparkSubmit { } // scalastyle:on println - def main(args: Array[String]): Unit = { + override def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { // scalastyle:off println @@ -322,7 +312,7 @@ object SparkSubmit { } // Require all R files to be local - if (args.isR && !isYarnCluster) { + if (args.isR && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } @@ -330,9 +320,6 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -410,9 +397,9 @@ object SparkSubmit { printErrorAndExit("Distributing R packages with 
standalone cluster is not supported.") } - // TODO: Support SparkR with mesos cluster - if (args.isR && clusterManager == MESOS) { - printErrorAndExit("SparkR is not supported for Mesos cluster.") + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -598,6 +585,9 @@ object SparkSubmit { if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") } else { childArgs += (args.primaryResource, args.mainClass) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec92..b1d36e1821cc7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -412,10 +412,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - value.split("=", 2).toSeq match { - case Seq(k, v) => sparkProperties(k) = v - case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") - } + val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + sparkProperties(confName) = confValue case PROXY_USER => proxyUser = value diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff836466..d7d82800b8b55 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -74,6 +74,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + return 0; + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + return 0; + } + /** * Returns a list of applications available for the history server to show. 
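The two provider hooks documented above ship with safe defaults, so only providers that actually track pending replays or scan times need to override them; the file-system provider changed later in this diff backs them with plain atomics. A rough standalone sketch of that pattern (names mirror the patch, but this is an illustration, not the provider code itself):

```scala
import java.util.concurrent.atomic.{AtomicInteger, AtomicLong}

val pendingReplayTasksCount = new AtomicInteger(0)
val lastScanTime = new AtomicLong(-1L)

pendingReplayTasksCount.addAndGet(3)       // a batch of event logs submitted for replay
pendingReplayTasksCount.decrementAndGet()  // one replay finished (done in a finally block in the patch)
lastScanTime.set(System.currentTimeMillis())

// What the History Server UI would read back:
val underProcess: Int = pendingReplayTasksCount.get()
val lastUpdated: Long = lastScanTime.get()
```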
* diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c818..ca38a47639422 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -108,7 +108,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +120,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +228,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -329,26 +335,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +388,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14c..0e7a6c24d4fa5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -30,13 +30,30 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
    {providerConfig.map { case (k, v) =>
  • {k}: {v}
  • }}
+ { + if (eventLogsUnderProcessCount > 0) { +

There are {eventLogsUnderProcessCount} event log(s) currently being + processed which may result in additional applications getting listed on this page. + Refresh the page to view updates.

+ } + } + + { + if (lastUpdatedTime > 0) { +

Last updated: {lastUpdatedTime}

+ } + } + { if (allAppsSize > 0) { ++ @@ -46,6 +63,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } else if (requestedIncomplete) {

No incomplete applications found!

+ } else if (eventLogsUnderProcessCount > 0) { +

No completed applications found!

} else {

No completed applications found!

++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56f..7e21fa681aa1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c4..4618e6117a4fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a969..8b1c6bf2e5fd5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -68,7 +68,7 @@ private[deploy] class Worker( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f66510b6f977f..59404e08895a3 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,6 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} +import 
org.apache.spark.internal.config +import org.apache.spark.SparkContext + /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -40,9 +43,14 @@ private[spark] abstract class StreamFileInputFormat[T] * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API * which is set through setMaxSplitSize */ - def setMinPartitions(context: JobContext, minPartitions: Int) { - val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum - val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong + def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { + val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultParallelism = sc.defaultParallelism + val files = listStatus(context).asScala + val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum + val bytesPerCore = totalBytes / defaultParallelism + val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) super.setMaxSplitSize(maxSplitSize) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 497ca92c7bc60..2951bdc18bc57 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -206,4 +206,21 @@ package object config { "encountering corrupt files and contents that have been read will still be returned.") .booleanConf .createWithDefault(false) + + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") + .stringConf + .createOptional + + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) + + private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes") + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimate, then the partitions with small files will be faster than partitions with" + + " bigger files.") + .longConf + .createWithDefault(4 * 1024 * 1024) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 0000000000000..afd2250c93a8a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. A few notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1. The driver calls setupJob. + * 2. As part of each task's execution, the executor calls setupTask and then commitTask + * (or abortTask if the task failed). + * 3. When all necessary tasks have completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the driver should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Sets up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or is killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task-related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Similar to newTaskTempFile(), but allows files to be committed to an absolute output location.
+ * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 0000000000000..b2d9b8d2a012f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
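// Minimal sketch of the reflective construction done by FileCommitProtocol.instantiate above
// (assumes the HadoopMapReduceCommitProtocol defined next; the output path is illustrative).
import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol}

val committer: FileCommitProtocol = FileCommitProtocol.instantiate(
  className = classOf[HadoopMapReduceCommitProtocol].getName,
  jobId = "0",
  outputPath = "/tmp/example-output",  // illustrative path
  isAppend = false)
// The (String, String, Boolean) constructor lookup fails for this class, so instantiate
// falls back to its (jobId: String, path: String) constructor.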
+ */ + +package org.apache.spark.internal.io + +import java.util.{Date, UUID} + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configurable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. + */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + val format = context.getOutputFormatClass.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + format match { + case c: Configurable => c.setConf(context.getConfiguration) + case _ => () + } + format.getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + val filename = getFilename(taskContext, ext) + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
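+ // Concretely, for split 3 with jobId "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb" and ext ".gz.parquet",
+ // the format string below produces "part-00003-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb.gz.parquet".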
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapred.job.id", jobId.toString) + jobContext.getConfiguration.set("mapred.tip.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapred.task.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapred.task.is.map", true) + jobContext.getConfiguration.setInt("mapred.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + new TaskCommitMessage(addedAbsPathFiles.toMap) + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } + } + + /** Whether we are using a direct output committer */ + def isDirectOutput(): Boolean = committer.getClass.getSimpleName.contains("Direct") +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala new file mode 100644 index 0000000000000..aaeb3d003829a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.reflect.ClassTag +import scala.util.DynamicVariable + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.{SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * A helper object that saves an RDD using a Hadoop OutputFormat + * (from the newer mapreduce API, not the old mapred API). + */ +private[spark] +object SparkHadoopMapReduceWriter extends Logging { + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + hadoopConf: Configuration): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + val sparkConf = rdd.conf + val conf = new SerializableConfiguration(hadoopConf) + + // Set up a job. + val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) + val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) + val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) + val format = jobContext.getOutputFormatClass + + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { + // FileOutputFormat ignores the filesystem parameter + val jobFormat = format.newInstance + jobFormat.checkOutputSpecs(jobContext) + } + + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = stageId.toString, + outputPath = conf.value.get("mapred.output.dir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] + committer.setupJob(jobContext) + + // When speculation is on and output committer class name contains "Direct", we should warn + // users that they may loss data if they are using a direct output committer. + if (SparkHadoopWriterUtils.isSpeculationEnabled(sparkConf) && committer.isDirectOutput) { + val warningMessage = + s"$committer may be an output committer that writes data directly to " + + "the final location. 
Because speculation is enabled, this output committer may " + + "cause data loss (see the case in SPARK-10063). If possible, please use an output " + + "committer that does not have this behavior (e.g. FileOutputCommitter)." + logWarning(warningMessage) + } + + // Try to write all RDD partitions as a Hadoop OutputFormat. + try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + hadoopConf = conf.value, + outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write an RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + hadoopConf: Configuration, + outputFormat: Class[_ <: OutputFormat[K, V]], + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, + sparkPartitionId, sparkAttemptNumber) + val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) + committer.setupTask(taskContext) + + val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = + SparkHadoopWriterUtils.initHadoopOutputMetrics(context) + + // Initiate the writer. + val taskFormat = outputFormat.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + taskFormat match { + case c: Configurable => c.setConf(hadoopConf) + case _ => () + } + val writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[RecordWriter[K, V]] + require(writer != null, "Unable to obtain RecordWriter") + var recordsWritten = 0L + + // Write all rows in RDD partition. 
+ try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + while (iterator.hasNext) { + val pair = iterator.next() + writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + SparkHadoopWriterUtils.maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback, recordsWritten) + recordsWritten += 1 + } + + committer.commitTask(taskContext) + }(catchBlock = { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + }, finallyBlock = writer.close(taskContext)) + + outputMetricsAndBytesWrittenCallback.foreach { + case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} + +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + def isSpeculationEnabled(conf: SparkConf): Boolean = { + conf.getBoolean("spark.speculation", false) + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + // return type: (output metrics, bytes written callback), defined only if the latter is defined + def initHadoopOutputMetrics( + context: TaskContext): Option[(OutputMetrics, () => Long)] = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + bytesWrittenCallback.map { b => + (context.taskMetrics().outputMetrics, b) + } + } + + def maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetricsAndBytesWrittenCallback.foreach { + case (om, callback) => + om.setBytesWritten(callback()) + om.setRecordsWritten(recordsWritten) + } + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. 
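+ * For example, an internal caller can bypass the check for a single save (a sketch; `rdd` and
+ * `conf` are assumed to be in scope):
+ * {{{
+ * SparkHadoopWriterUtils.disableOutputSpecValidation.withValue(true) {
+ *   rdd.saveAsNewAPIHadoopDataset(conf)  // checkOutputSpecs() is skipped for this call
+ * }
+ * }}}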
+ */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef755..6ba79e506a648 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -32,9 +32,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. + * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -103,9 +102,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +122,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +142,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. 
*/ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6a..8f579c5a3033c 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 41832e8354741..50d977a92da51 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat private[spark] class BinaryFileRDD[T]( - sc: SparkContext, + @transient private val sc: SparkContext, inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], @@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T]( case _ => } val jobContext = new JobContextImpl(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) + inputFormat.setMinPartitions(sc, jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 2381f54ee3f06..a091f06b4ed7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -66,14 +66,14 @@ private[spark] class CoGroupPartition( /** * :: DeveloperApi :: - * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a + * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * - * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of - * instantiating this directly. - * * @param rdds parent RDDs. * @param part partitioner used to partition the shuffle output + * + * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of + * instantiating this directly. */ @DeveloperApi class CoGroupedRDD[K: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a05a770b40c57..f3ab324d59119 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -158,7 +158,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g 1<=x<10 , 10<=x<20, 20<=x<=50 * And on the input of 1 and 50 we would have a histogram of 1, 0, 1 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. 
[0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de098..86351b8c575e5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag @@ -84,9 +84,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +94,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. + * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.hadoopRDD()]] */ @DeveloperApi class HadoopRDD[K, V]( @@ -243,7 +243,8 @@ class HadoopRDD[K, V]( var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb658870..a5965f597038d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. 
+ * + * @note Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad745..33e695ec5322b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -18,33 +18,31 @@ package org.apache.spark.rdd import java.nio.ByteBuffer -import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{HashMap => JHashMap} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol, SparkHadoopMapReduceWriter, SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -59,8 +57,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -68,6 +66,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). 
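+ * For example, to collect the values for each key into a list (a sketch; `pairs` is assumed
+ * to be an `RDD[(String, Int)]`):
+ * {{{
+ * pairs.combineByKeyWithClassTag(
+ *   (v: Int) => List(v),
+ *   (c: List[Int], v: Int) => v :: c,
+ *   (c1: List[Int], c2: List[Int]) => c1 ::: c2)
+ * }}}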
+ * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -363,7 +364,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -490,11 +491,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { @@ -514,11 +515,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { @@ -635,7 +636,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -1016,7 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. 
do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1060,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } FileOutputFormat.setOutputPath(hadoopConf, - SparkHadoopWriter.createPathFromString(path, hadoopConf)) + SparkHadoopWriterUtils.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) } @@ -1070,86 +1071,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). - val hadoopConf = conf - val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - val jobConfiguration = job.getConfiguration - val wrappedConf = new SerializableConfiguration(jobConfiguration) - val outfmt = job.getOutputFormatClass - val jobFormat = outfmt.newInstance - - if (isOutputSpecValidationEnabled) { - // FileOutputFormat ignores the filesystem parameter - jobFormat.checkOutputSpecs(job) - } - - val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value - /* "reduce task" */ - val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, - context.attemptNumber) - val hadoopContext = new TaskAttemptContextImpl(config, attemptId) - val format = outfmt.newInstance - format match { - case c: Configurable => c.setConf(config) - case _ => () - } - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - - val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) - - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val pair = iter.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close(hadoopContext)) - committer.commitTask(hadoopContext) - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - 1 - } : Int - - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) - val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - - // When speculation is on and output committer class name contains "Direct", we 
should warn - // users that they may loss data if they are using a direct output committer. - val speculationEnabled = self.conf.getBoolean("spark.speculation", false) - val outputCommitterClass = jobCommitter.getClass.getSimpleName - if (speculationEnabled && outputCommitterClass.contains("Direct")) { - val warningMessage = - s"$outputCommitterClass may be an output committer that writes data directly to " + - "the final location. Because speculation is enabled, this output committer may " + - "cause data loss (see the case in SPARK-10063). If possible, please use an output " + - "committer that does not have this behavior (e.g. FileOutputCommitter)." - logWarning(warningMessage) - } - - jobCommitter.setupJob(jobTaskContext) - self.context.runJob(self, writeShard) - jobCommitter.commitJob(jobTaskContext) + SparkHadoopMapReduceWriter.write( + rdd = self, + hadoopConf = conf) } /** @@ -1178,7 +1108,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (isOutputSpecValidationEnabled) { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(hadoopConf) hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) @@ -1193,7 +1123,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) + SparkHadoopWriterUtils.initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1205,7 +1135,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) + SparkHadoopWriterUtils.maybeUpdateOutputMetrics( + outputMetricsAndBytesWrittenCallback, recordsWritten) recordsWritten += 1 } }(finallyBlock = writer.close()) @@ -1220,29 +1151,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - // return type: (output metrics, bytes written callback), defined only if the latter is defined - private def initHadoopOutputMetrics( - context: TaskContext): Option[(OutputMetrics, () => Long)] = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - bytesWrittenCallback.map { b => - (context.taskMetrics().outputMetrics, b) - } - } - - private def maybeUpdateOutputMetrics( - outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], - recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - } - } - /** * Return an RDD with the keys of each tuple. 
*/ @@ -1258,22 +1166,4 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def valueClass: Class[_] = vt.runtimeClass private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - private def isOutputSpecValidationEnabled: Boolean = { - val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value - val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } -} - -private[spark] object PairRDDFunctions { - val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. - */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee9..ce75a16031a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b9..6a89ea8786464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb3..bff2b8f1d06c9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -428,7 +428,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. 
Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -471,6 +471,9 @@ abstract class RDD[T: ClassTag]( * without replacement: probability that each element is chosen; fraction must be [0, 1] * with replacement: expected number of times each element is chosen; fraction must be >= 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -534,13 +537,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -615,7 +618,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -627,7 +630,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -643,7 +646,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -671,7 +674,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -684,7 +687,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. 
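+ * For instance, when only a per-key sum is needed, prefer
+ * `rdd.map(x => (f(x), g(x))).reduceByKey(_ + _)` over `rdd.groupBy(f).mapValues(_.map(g).sum)`
+ * (an illustrative sketch; `f` and `g` stand for arbitrary key and numeric value functions).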
*/ @@ -699,7 +702,7 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an + * @note This operation may be very expensive. If you are grouping in order to perform an * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. */ @@ -788,14 +791,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. */ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { @@ -906,7 +921,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -919,7 +934,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1167,7 +1182,7 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -1257,7 +1272,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. 
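+ * For example (a sketch; `sc` is a SparkContext):
+ * {{{
+ * sc.parallelize(Seq("a", "b", "c", "d"), numSlices = 2).zipWithIndex().collect()
+ * // returns Array((a,0), (b,1), (c,2), (d,3))
+ * }}}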
* - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1271,7 +1286,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1290,10 +1305,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1355,7 +1370,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1378,7 +1393,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1423,7 +1438,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) 
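// Illustrative sketch of the RDD[Nothing] pitfall called out in the note above (assumes an
// active SparkContext `sc`):
sc.parallelize(Seq(1, 2, 3)).take(2)    // Array(1, 2)
sc.parallelize(Seq[Int]()).take(1)      // Array(): an explicitly typed empty RDD is fine
// sc.parallelize(Seq()).take(1)        // inferred as RDD[Nothing]; raises an exception at runtime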
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6bee..1070bb96b2524 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -32,7 +32,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. It manages process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067c..e0a29b48314fb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -151,7 +151,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. */ def writePartitionToCheckpointFile[T: ClassTag]( path: String, @@ -239,12 +239,17 @@ private[spark] object ReliableCheckpointRDD extends Logging { val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) val fileInputStream = fs.open(partitionerFilePath, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - val partitioner = Utils.tryWithSafeFinally[Partitioner] { - deserializeStream.readObject[Partitioner] + val partitioner = Utils.tryWithSafeFinally { + val deserializeStream = serializer.deserializeStream(fileInputStream) + Utils.tryWithSafeFinally { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } } { - deserializeStream.close() + fileInputStream.close() } + logDebug(s"Read partitioner from $partitionerFilePath") Some(partitioner) } catch { diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c71..86a332790fb00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. 
*/ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c63..8425b211d6ecf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 579122868afc8..bbc416381490b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def openChannel(uri: String): ReadableByteChannel + /** + * Return if the current thread is a RPC thread. + */ + def isInRPCThread: Boolean } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index a02cf30a5d831..67baabd2cbff2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** Message loop used for dispatching messages. */ private class MessageLoop extends Runnable { override def run(): Unit = { + NettyRpcEnv.rpcThreadFlag.value = true try { while (true) { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e51649a1ecce9..0b8cd144a2161 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv( } + override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value } private[netty] object NettyRpcEnv extends Logging { + private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false) + /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. E.g., diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afec..0a5fe5a1d3ee1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. 
This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. - * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f2517401cb76b..7fde34d8974c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1089,7 +1089,8 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && !updates.isZero) { stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) - event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value)) + event.taskInfo.setAccumulables( + acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 2424586431aa0..0bd5a6bc59a9e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -53,13 +53,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { sourceName: String, maybeTruncated: Boolean = false, eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { + val lines = Source.fromInputStream(logData).getLines() + replay(lines, sourceName, maybeTruncated, eventsFilter) + } + /** + * Overloaded variant of [[replay()]] which accepts an iterator of lines instead of an + * [[InputStream]]. Exposed for use by custom ApplicationHistoryProvider implementations. 
+ */ + def replay( + lines: Iterator[String], + sourceName: String, + maybeTruncated: Boolean, + eventsFilter: ReplayEventsFilter): Unit = { var currentLine: String = null var lineNumber: Int = 0 try { - val lineEntries = Source.fromInputStream(logData) - .getLines() + val lineEntries = lines .zipWithIndex .filter { case (line, _) => eventsFilter(line) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e4..d39651a722325 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance @@ -92,8 +93,16 @@ private[spark] abstract class Task[T]( kill(interruptThread = false) } - new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), - Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + new CallerContext( + "TASK", + SparkEnv.get.conf.get(APP_CALLER_CONTEXT), + appId, + appAttemptId, + jobId, + Option(stageId), + Option(stageAttemptId), + Option(taskAttemptId), + Option(attemptNumber)).setCurrentContext() try { runTask(context) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index eeb7963c9e610..59680139e7af3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ListBuffer - import org.apache.spark.TaskState import org.apache.spark.TaskState.TaskState import org.apache.spark.annotation.DeveloperApi @@ -54,7 +52,13 @@ class TaskInfo( * accumulable to be updated multiple times in a single task or for two accumulables with the * same name but different IDs to exist in a task. 
*/ - val accumulables = ListBuffer[AccumulableInfo]() + def accumulables: Seq[AccumulableInfo] = _accumulables + + private[this] var _accumulables: Seq[AccumulableInfo] = Nil + + private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = { + _accumulables = newAccumulables + } /** * The time when the task has completed successfully (including the time to remotely fetch diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cff..368cd30a2e11a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -93,7 +93,7 @@ private[spark] class StandaloneSchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val webUrl = sc.ui.map(_.webUrl).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. @@ -103,8 +103,8 @@ private[spark] class StandaloneSchedulerBackend( } else { None } - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) + val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..f60dcfddfdc20 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe1076..19e020c968a9a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
* - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..afe6cd86059f0 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -40,7 +40,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. + * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573db..76af33c1a18db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd2382225..d8d5e8958b23c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala 
b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0f..e12f2e6095d5a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -71,7 +71,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +80,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +128,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +142,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,19 +156,22 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. 
*/ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f631a047a707d..b828532aba7a3 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -82,7 +82,7 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + environmentListener.systemProperties.toMap.getOrElse("user.name", "") } def getAppName: String = appName @@ -94,16 +94,9 @@ private[spark] class SparkUI private ( /** Stop the server behind this web interface. Only valid after bind(). */ override def stop() { super.stop() - logInfo("Stopped Spark web UI at %s".format(appUIAddress)) + logInfo(s"Stopped Spark web UI at $webUrl") } - /** - * Return the application UI host:port. This does not include the scheme (http://). - */ - private[spark] def appUIHostPort = publicHostName + ":" + boundPort - - private[spark] def appUIAddress = s"http://$appUIHostPort" - def getSparkUI(appId: String): Option[SparkUI] = { if (appId == this.appId) Some(this) else None } @@ -136,7 +129,7 @@ private[spark] class SparkUI private ( private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) extends WebUITab(parent, prefix) { - def appName: String = parent.getAppName + def appName: String = parent.appName } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62a..57f6f2f0a9be5 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -36,7 +36,8 @@ private[spark] object UIUtils extends Logging { // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) @@ -170,6 +171,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a05e0efb7a3e3..8c801558672fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -56,8 +56,8 @@ private[spark] abstract class WebUI( private val className = Utils.getFormattedClassName(this) def getBasePath: String = basePath - def getTabs: Seq[WebUITab] = tabs.toSeq - def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + def getTabs: Seq[WebUITab] = tabs + def getHandlers: Seq[ServletContextHandler] = handlers def getSecurityManager: SecurityManager = securityManager /** Attach a tab to this UI, along with all of its attached pages. */ @@ -133,7 +133,7 @@ private[spark] abstract class WebUI( def initialize(): Unit /** Bind to the HTTP server behind this web interface. 
*/ - def bind() { + def bind(): Unit = { assert(!serverInfo.isDefined, s"Attempted to bind $className more than once!") try { val host = Option(conf.getenv("SPARK_LOCAL_IP")).getOrElse("0.0.0.0") @@ -156,7 +156,7 @@ private[spark] abstract class WebUI( def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) /** Stop the server behind this web interface. Only valid after bind(). */ - def stop() { + def stop(): Unit = { assert(serverInfo.isDefined, s"Attempted to stop $className before binding to a server!") serverInfo.get.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index a0ef80d9bdae0..c6a07445f2a35 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -48,6 +48,16 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } }.map { thread => val threadId = thread.threadId + val blockedBy = thread.blockedByThreadId match { + case Some(blockedByThreadId) => + + case None => Text("") + } + val heldLocks = thread.holdingLocks.mkString(", ") + {threadId} {thread.threadName} {thread.threadState} + {blockedBy}{heldLocks} {thread.stackTrace} } @@ -86,6 +97,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage Thread ID Thread Name Thread State + Thread Locks {dumpRows} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 173fc3cf31ce8..50e8e2d19e155 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -289,8 +289,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val startTime = listener.startTime val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse.toSeq - val failedJobs = listener.failedJobs.reverse.toSeq + val completedJobs = listener.completedJobs.reverse + val failedJobs = listener.failedJobs.reverse val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f4a04609c4c69..9ce8542f02791 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet @@ -147,9 +147,8 @@ private[spark] object UIData { memoryBytesSpilled = m.memoryBytesSpilled, diskBytesSpilled = m.diskBytesSpilled, peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead), - outputMetrics = - OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten), + inputMetrics = InputMetricsUIData(m.inputMetrics), + outputMetrics = OutputMetricsUIData(m.outputMetrics), 
shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) } @@ -171,9 +170,9 @@ private[spark] object UIData { speculative = taskInfo.speculative ) newTaskInfo.gettingResultTime = taskInfo.gettingResultTime - newTaskInfo.accumulables ++= taskInfo.accumulables.filter { + newTaskInfo.setAccumulables(taskInfo.accumulables.filter { accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) - } + }) newTaskInfo.finishTime = taskInfo.finishTime newTaskInfo.failed = taskInfo.failed newTaskInfo @@ -197,8 +196,32 @@ private[spark] object UIData { shuffleWriteMetrics: ShuffleWriteMetricsUIData) case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) + object InputMetricsUIData { + def apply(metrics: InputMetrics): InputMetricsUIData = { + if (metrics.bytesRead == 0 && metrics.recordsRead == 0) { + EMPTY + } else { + new InputMetricsUIData( + bytesRead = metrics.bytesRead, + recordsRead = metrics.recordsRead) + } + } + private val EMPTY = InputMetricsUIData(0, 0) + } case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) + object OutputMetricsUIData { + def apply(metrics: OutputMetrics): OutputMetricsUIData = { + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) { + EMPTY + } else { + new OutputMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten) + } + } + private val EMPTY = OutputMetricsUIData(0, 0) + } case class ShuffleReadMetricsUIData( remoteBlocksFetched: Long, @@ -212,17 +235,30 @@ private[spark] object UIData { object ShuffleReadMetricsUIData { def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { - new ShuffleReadMetricsUIData( - remoteBlocksFetched = metrics.remoteBlocksFetched, - localBlocksFetched = metrics.localBlocksFetched, - remoteBytesRead = metrics.remoteBytesRead, - localBytesRead = metrics.localBytesRead, - fetchWaitTime = metrics.fetchWaitTime, - recordsRead = metrics.recordsRead, - totalBytesRead = metrics.totalBytesRead, - totalBlocksFetched = metrics.totalBlocksFetched - ) + if ( + metrics.remoteBlocksFetched == 0 && + metrics.localBlocksFetched == 0 && + metrics.remoteBytesRead == 0 && + metrics.localBytesRead == 0 && + metrics.fetchWaitTime == 0 && + metrics.recordsRead == 0 && + metrics.totalBytesRead == 0 && + metrics.totalBlocksFetched == 0) { + EMPTY + } else { + new ShuffleReadMetricsUIData( + remoteBlocksFetched = metrics.remoteBlocksFetched, + localBlocksFetched = metrics.localBlocksFetched, + remoteBytesRead = metrics.remoteBytesRead, + localBytesRead = metrics.localBytesRead, + fetchWaitTime = metrics.fetchWaitTime, + recordsRead = metrics.recordsRead, + totalBytesRead = metrics.totalBytesRead, + totalBlocksFetched = metrics.totalBlocksFetched + ) + } } + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( @@ -232,12 +268,17 @@ private[spark] object UIData { object ShuffleWriteMetricsUIData { def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { - new ShuffleWriteMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten, - writeTime = metrics.writeTime - ) + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) { + EMPTY + } else { + new ShuffleWriteMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten, + writeTime = metrics.writeTime + ) + } } + private val EMPTY = 
ShuffleWriteMetricsUIData(0, 0, 0) } } diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd39131326..1326f0977c241 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -59,8 +59,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } /** - * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before use, or it will throw exception. + * Returns true if this accumulator has been registered. + * + * @note All accumulators must be registered before use, or it will throw exception. */ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala new file mode 100644 index 0000000000000..d73901686b705 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +import org.apache.spark.SparkException + +/** + * Contains basic command line parsing functionality and methods to parse some common Spark CLI + * options. 
+ */ +private[spark] trait CommandLineUtils { + + // Exposed for testing + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + + private[spark] var printStream: PrintStream = System.err + + // scalastyle:off println + + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + + private[spark] def printErrorAndExit(str: String): Unit = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn(1) + } + + // scalastyle:on println + + private[spark] def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => printErrorAndExit(s"Spark config without '=': $pair") + throw new SparkException(s"Spark config without '=': $pair") + } + } + + def main(args: Array[String]): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c11eb3ffa4601..4b4d2d10cbf8d 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -107,20 +107,20 @@ private[spark] object JsonProtocol { def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted): JValue = { val stageInfo = stageInfoToJson(stageSubmitted.stageInfo) val properties = propertiesToJson(stageSubmitted.properties) - ("Event" -> Utils.getFormattedClassName(stageSubmitted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) ~ ("Stage Info" -> stageInfo) ~ ("Properties" -> properties) } def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = { val stageInfo = stageInfoToJson(stageCompleted.stageInfo) - ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) ~ ("Stage Info" -> stageInfo) } def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = { val taskInfo = taskStart.taskInfo - ("Event" -> Utils.getFormattedClassName(taskStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) ~ ("Stage ID" -> taskStart.stageId) ~ ("Stage Attempt ID" -> taskStart.stageAttemptId) ~ ("Task Info" -> taskInfoToJson(taskInfo)) @@ -128,7 +128,7 @@ private[spark] object JsonProtocol { def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = { val taskInfo = taskGettingResult.taskInfo - ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) ~ ("Task Info" -> taskInfoToJson(taskInfo)) } @@ -137,7 +137,7 @@ private[spark] object JsonProtocol { val taskInfo = taskEnd.taskInfo val taskMetrics = taskEnd.taskMetrics val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing - ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) ~ ("Stage ID" -> taskEnd.stageId) ~ ("Stage Attempt ID" -> taskEnd.stageAttemptId) ~ ("Task Type" -> taskEnd.taskType) ~ @@ -148,7 +148,7 @@ private[spark] object JsonProtocol { def jobStartToJson(jobStart: SparkListenerJobStart): JValue = { val properties = propertiesToJson(jobStart.properties) - ("Event" -> Utils.getFormattedClassName(jobStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) ~ ("Job ID" -> jobStart.jobId) ~ ("Submission Time" -> jobStart.time) ~ ("Stage Infos" 
-> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 @@ -158,7 +158,7 @@ private[spark] object JsonProtocol { def jobEndToJson(jobEnd: SparkListenerJobEnd): JValue = { val jobResult = jobResultToJson(jobEnd.jobResult) - ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobEnd) ~ ("Job ID" -> jobEnd.jobId) ~ ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) @@ -170,7 +170,7 @@ private[spark] object JsonProtocol { val sparkProperties = mapToJson(environmentDetails("Spark Properties").toMap) val systemProperties = mapToJson(environmentDetails("System Properties").toMap) val classpathEntries = mapToJson(environmentDetails("Classpath Entries").toMap) - ("Event" -> Utils.getFormattedClassName(environmentUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.environmentUpdate) ~ ("JVM Information" -> jvmInformation) ~ ("Spark Properties" -> sparkProperties) ~ ("System Properties" -> systemProperties) ~ @@ -179,7 +179,7 @@ private[spark] object JsonProtocol { def blockManagerAddedToJson(blockManagerAdded: SparkListenerBlockManagerAdded): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ ("Timestamp" -> blockManagerAdded.time) @@ -187,18 +187,18 @@ private[spark] object JsonProtocol { def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerRemoved) ~ ("Block Manager ID" -> blockManagerId) ~ ("Timestamp" -> blockManagerRemoved.time) } def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = { - ("Event" -> Utils.getFormattedClassName(unpersistRDD)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.unpersistRDD) ~ ("RDD ID" -> unpersistRDD.rddId) } def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationStart) ~ ("App Name" -> applicationStart.appName) ~ ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~ ("Timestamp" -> applicationStart.time) ~ @@ -208,33 +208,33 @@ private[spark] object JsonProtocol { } def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationEnd) ~ ("Timestamp" -> applicationEnd.time) } def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { - ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~ ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { - ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorRemoved) ~ ("Timestamp" -> executorRemoved.time) ~ 
("Executor ID" -> executorRemoved.executorId) ~ ("Removed Reason" -> executorRemoved.reason) } def logStartToJson(logStart: SparkListenerLogStart): JValue = { - ("Event" -> Utils.getFormattedClassName(logStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.logStart) ~ ("Spark Version" -> SPARK_VERSION) } def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId val accumUpdates = metricsUpdate.accumUpdates - ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~ ("Executor ID" -> execId) ~ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ @@ -485,7 +485,7 @@ private[spark] object JsonProtocol { * JSON deserialization methods for SparkListenerEvents | * ---------------------------------------------------- */ - def sparkEventFromJson(json: JValue): SparkListenerEvent = { + private object SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES { val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted) val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted) val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart) @@ -503,6 +503,10 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + } + + def sparkEventFromJson(json: JValue): SparkListenerEvent = { + import SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES._ (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -540,7 +544,8 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -552,7 +557,8 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -662,20 +668,22 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = (json \ "Details").extractOpt[String].getOrElse("") + val details = Utils.jsonOption(json \ 
"Details").map(_.extract[String]).getOrElse("") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) - val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson) - case None => Seq[AccumulableInfo]() + val accumulatedValues = { + Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + case Some(values) => values.map(accumulableInfoFromJson) + case None => Seq[AccumulableInfo]() + } } val stageInfo = new StageInfo( @@ -692,17 +700,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1) + val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false) + val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = (json \ "Killed").extractOpt[Boolean].getOrElse(false) - val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { + val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -713,18 +721,19 @@ private[spark] object JsonProtocol { taskInfo.finishTime = finishTime taskInfo.failed = failed taskInfo.killed = killed - accumulables.foreach { taskInfo.accumulables += _ } + taskInfo.setAccumulables(accumulables) taskInfo } def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extractOpt[String] + val name = Utils.jsonOption(json \ "Name").map(_.extract[String]) val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - val metadata = (json \ "Metadata").extractOpt[String] + val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val countFailedValues = + Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -782,9 +791,11 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) 
readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incLocalBytesRead( + Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) - readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incRecordsRead( + Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } @@ -793,8 +804,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) - writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") - .extractOpt[Long].getOrElse(0L)) + writeMetrics.incRecordsWritten( + Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } @@ -802,14 +813,16 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) - outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) + outputMetrics.setRecordsWritten( + Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incRecordsRead( + Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks @@ -824,7 +837,7 @@ private[spark] object JsonProtocol { metrics } - def taskEndReasonFromJson(json: JValue): TaskEndReason = { + private object TASK_END_REASON_FORMATTED_CLASS_NAMES { val success = Utils.getFormattedClassName(Success) val resubmitted = Utils.getFormattedClassName(Resubmitted) val fetchFailed = Utils.getFormattedClassName(FetchFailed) @@ -834,6 +847,10 @@ private[spark] object JsonProtocol { val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) + } + + def taskEndReasonFromJson(json: JValue): TaskEndReason = { + import TASK_END_REASON_FORMATTED_CLASS_NAMES._ (json \ "Reason").extract[String] match { case `success` => Success @@ -850,7 +867,8 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + val fullStackTrace = + Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) @@ 
-885,15 +903,19 @@ private[spark] object JsonProtocol { if (json == JNothing) { return null } - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val port = (json \ "Port").extract[Int] BlockManagerId(executorId, host, port) } - def jobResultFromJson(json: JValue): JobResult = { + private object JOB_RESULT_FORMATTED_CLASS_NAMES { val jobSucceeded = Utils.getFormattedClassName(JobSucceeded) val jobFailed = Utils.getFormattedClassName(JobFailed) + } + + def jobResultFromJson(json: JValue): JobResult = { + import JOB_RESULT_FORMATTED_CLASS_NAMES._ (json \ "Result").extract[String] match { case `jobSucceeded` => JobSucceeded diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966a..b1217980faf1f 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee8..748d729554fca 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -39,6 +39,7 @@ import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses @@ -1418,8 +1419,12 @@ private[spark] object Utils extends Logging { } callStack(0) = ste.toString // Put last Spark method on top of the stack trace. } else { - firstUserLine = ste.getLineNumber - firstUserFile = ste.getFileName + if (ste.getFileName != null) { + firstUserFile = ste.getFileName + if (ste.getLineNumber >= 0) { + firstUserLine = ste.getLineNumber + } + } callStack += ste.toString insideSpark = false } @@ -2051,6 +2056,20 @@ private[spark] object Utils extends Logging { path } + /** + * Updates Spark config with properties from a set of Properties. + * Provided properties have the highest priority. + */ + def updateSparkConfigFromProperties( + conf: SparkConf, + properties: Map[String, String]) : Unit = { + properties.filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.set(k, v) + } + } + /** Load properties present in the given file. 
*/ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) @@ -2096,15 +2115,41 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = (threadInfo.getLockedSynchronizers.map(_.lockString) + ++ threadInfo.getLockedMonitors.map(_.lockString) + ).toSet + + ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, threadInfo.getThreadState, + stackTrace, if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), heldLocks.toSeq) } } @@ -2196,6 +2241,9 @@ private[spark] object Utils extends Logging { isBindCollision(e.getCause) case e: MultiException => e.getThrowables.asScala.exists(isBindCollision) + case e: NativeIoException => + (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || + isBindCollision(e.getCause) case e: Exception => isBindCollision(e.getCause) case _ => false } @@ -2513,6 +2561,8 @@ private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { try { + // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in + // master Maven build, so do not use it before resolving SPARK-17714. 
// scalastyle:off classforname Class.forName("org.apache.hadoop.ipc.CallerContext") Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") @@ -2541,6 +2591,7 @@ private[util] object CallerContext extends Logging { * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) * * The parameters below are optional: + * @param upstreamCallerContext caller context the upstream application passes in * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to * @param jobId id of the job this task belongs to @@ -2550,26 +2601,38 @@ private[util] object CallerContext extends Logging { * @param taskAttemptNumber task attempt id */ private[spark] class CallerContext( - from: String, - appId: Option[String] = None, - appAttemptId: Option[String] = None, - jobId: Option[Int] = None, - stageId: Option[Int] = None, - stageAttemptId: Option[Int] = None, - taskId: Option[Long] = None, - taskAttemptNumber: Option[Int] = None) extends Logging { - - val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" - val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" - val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" - val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" - val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" - val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" - val taskAttemptNumberStr = - if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" - - val context = "SPARK_" + from + appIdStr + appAttemptIdStr + - jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + from: String, + upstreamCallerContext: Option[String] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + private val context = prepareContext("SPARK_" + + from + + appId.map("_" + _).getOrElse("") + + appAttemptId.map("_" + _).getOrElse("") + + jobId.map("_JId_" + _).getOrElse("") + + stageId.map("_SId_" + _).getOrElse("") + + stageAttemptId.map("_" + _).getOrElse("") + + taskId.map("_TId_" + _).getOrElse("") + + taskAttemptNumber.map("_" + _).getOrElse("") + + upstreamCallerContext.map("_" + _).getOrElse("")) + + private def prepareContext(context: String): String = { + // The default max size of Hadoop caller context is 128 + lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128) + if (context == null || context.length <= len) { + context + } else { + val finalContext = context.substring(0, len) + logWarning(s"Truncated Spark caller context from $context to $finalContext") + finalContext + } + } /** * Set up the caller context [[context]] by invoking Hadoop CallerContext API of @@ -2578,6 +2641,8 @@ private[spark] class CallerContext( def setCurrentContext(): Unit = { if (CallerContext.callerContextSupported) { try { + // `Utils.classForName` will make `ReplSuite` fail with `ClassCircularityError` in + // master Maven build, so do not use it before resolving SPARK-17714. 
// scalastyle:off classforname val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 6b74a29aceda9..bcb95b416dd25 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k.equals(curKey)) { - val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) - data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] - return newValue - } else if (curKey.eq(null)) { + if (curKey.eq(null)) { val newValue = updateFunc(false, null.asInstanceOf[V]) data(2 * pos) = k data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] incrementSize() return newValue + } else if (k.eq(curKey) || k.equals(curKey)) { + val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue } else { val delta = i pos = (pos + delta) & mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 0f6a425e3db9a..60f6f537c1d54 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -48,7 +48,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( require(initialCapacity <= OpenHashSet.MAX_CAPACITY, s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") - require(initialCapacity >= 1, "Invalid initial capacity") + require(initialCapacity >= 0, "Invalid initial capacity") require(loadFactor < 1.0, "Load factor must be less than 1.0") require(loadFactor > 0.0, "Load factor must be greater than 0.0") @@ -271,8 +271,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { - val highBit = Integer.highestOneBit(n) - if (highBit == n) n else highBit << 1 + if (n == 0) { + 1 + } else { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a1..1f263df57c857 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var 
bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 533025ba83e72..7bebe0612f9a8 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -20,7 +20,6 @@ import java.io.*; import java.nio.channels.FileChannel; import java.nio.ByteBuffer; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -46,6 +45,7 @@ import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; @@ -1075,18 +1075,23 @@ public void wholeTextFiles() throws Exception { byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); - Files.write(content1, new File(tempDirName + "/part-00000")); - Files.write(content2, new File(tempDirName + "/part-00001")); + String path1 = new Path(tempDirName, "part-00000").toUri().getPath(); + String path2 = new Path(tempDirName, "part-00001").toUri().getPath(); + + Files.write(content1, new File(path1)); + Files.write(content2, new File(path2)); Map container = new HashMap<>(); - container.put(tempDirName+"/part-00000", new Text(content1).toString()); - container.put(tempDirName+"/part-00001", new Text(content2).toString()); + container.put(path1, new Text(content1).toString()); + container.put(path2, new Text(content2).toString()); JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + // Note that the paths from `wholeTextFiles` are in URI format on Windows, + // for example, file:/C:/a/b/c. 
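(Aside: the assertion just below works because both the expected map and the lookup normalize their keys through Hadoop's `Path`, so a URI-style result such as `file:/C:/a/b/c` compares equal to the plain local path. A minimal sketch of that normalization, written in Scala for brevity even though the surrounding test is Java, using a made-up path:)

```scala
import org.apache.hadoop.fs.Path

// Hypothetical key in the URI form that wholeTextFiles can return on Windows.
val uriStyleKey = "file:/C:/tmp/whole-text/part-00000"

// Path(...).toUri.getPath drops the "file:" scheme, leaving the plain path
// that the expected-results map is keyed by.
val normalized = new Path(uriStyleKey).toUri.getPath
println(normalized) // expected: /C:/tmp/whole-text/part-00000
```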
+ assertEquals(res._2(), container.get(new Path(res._1()).toUri().getPath())); } } diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index cc52bb1d23cd5..89f0b1cb5b56a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -58,10 +58,15 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { nums.saveAsTextFile(outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") - val content = Source.fromFile(outputFile).mkString - assert(content === "1\n2\n3\n4\n") - // Also try reading it in as a text file RDD - assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + val bufferSrc = Source.fromFile(outputFile) + Utils.tryWithSafeFinally { + val content = bufferSrc.mkString + assert(content === "1\n2\n3\n4\n") + // Also try reading it in as a text file RDD + assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + } { + bufferSrc.close() + } } test("text files (compressed)") { diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 13cba94578a6a..005587051b6ad 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{ResetSystemProperties, Utils} class RPackageUtilsSuite extends SparkFunSuite @@ -74,9 +74,13 @@ class RPackageUtilsSuite val deps = Seq(dep1, dep2).mkString(",") IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo))))) - assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") - assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") - assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + Utils.tryWithSafeFinally { + assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") + assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") + assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + } { + jars.foreach(_.close()) + } } } @@ -131,7 +135,7 @@ class RPackageUtilsSuite test("SparkR zipping works properly") { val tempDir = Files.createTempDir() - try { + Utils.tryWithSafeFinally { IvyTestUtils.writeFile(tempDir, "test.R", "abc") val fakeSparkRDir = new File(tempDir, "SparkR") assert(fakeSparkRDir.mkdirs()) @@ -144,14 +148,19 @@ class RPackageUtilsSuite IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") assert(finalZip.exists()) - val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq - assert(entries.contains("/test.R")) - assert(entries.contains("/SparkR/abc.R")) - assert(entries.contains("/SparkR/DESCRIPTION")) - assert(!entries.contains("/package.zip")) - assert(entries.contains("/packageTest/def.R")) - assert(entries.contains("/packageTest/DESCRIPTION")) - } finally { + val zipFile = new ZipFile(finalZip) + Utils.tryWithSafeFinally { + val entries = 
zipFile.entries().asScala.map(_.getName).toSeq + assert(entries.contains("/test.R")) + assert(entries.contains("/SparkR/abc.R")) + assert(entries.contains("/SparkR/DESCRIPTION")) + assert(!entries.contains("/package.zip")) + assert(entries.contains("/packageTest/def.R")) + assert(entries.contains("/packageTest/DESCRIPTION")) + } { + zipFile.close() + } + } { FileUtils.deleteDirectory(tempDir) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c649e305a37e..626888022903b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -34,21 +34,11 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} -// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that needed to be cleared after tests. -class SparkSubmitSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { - override def beforeEach() { - super.beforeEach() - System.setProperty("spark.testing", "true") - } +trait TestPrematureExit { + suite: SparkFunSuite => private val noOpOutputStream = new OutputStream { def write(b: Int) = {} @@ -65,16 +55,19 @@ class SparkSubmitSuite } /** Returns true if the script exits and the given search string is printed. */ - private def testPrematureExit(input: Array[String], searchString: String) = { + private[spark] def testPrematureExit( + input: Array[String], + searchString: String, + mainObject: CommandLineUtils = SparkSubmit) : Unit = { val printStream = new BufferPrintStream() - SparkSubmit.printStream = printStream + mainObject.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = (_) => exitedCleanly = true + mainObject.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { - SparkSubmit.main(input) + mainObject.main(input) } catch { // If exceptions occur after the "exit" has happened, fine to ignore them. // These represent code paths not reachable during normal execution. @@ -88,6 +81,22 @@ class SparkSubmitSuite fail(s"Search string '$searchString' not found in $joined") } } +} + +// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch +// of properties that needed to be cleared after tests. 
+class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts + with TestPrematureExit { + + override def beforeEach() { + super.beforeEach() + System.setProperty("spark.testing", "true") + } // scalastyle:off println test("prints usage on empty input") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a5eda7b5a5a75..2c41c432d1fe2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -449,8 +449,14 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) val bstream = new BufferedOutputStream(cstream) if (isNewFormat) { - EventLoggingListener.initEventLog(new FileOutputStream(file)) + val newFormatStream = new FileOutputStream(file) + Utils.tryWithSafeFinally { + EventLoggingListener.initEventLog(newFormatStream) + } { + newFormatStream.close() + } } + val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8) Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a595bc174a310..715811a46f42d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -29,6 +29,8 @@ import com.codahale.metrics.Counter import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.eclipse.jetty.proxy.ProxyServlet +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ @@ -258,8 +260,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { - val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] @@ -267,7 +268,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // when System.setProperty("spark.ui.proxyBase", uiRoot) val response = page.render(request) - System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse("")) // then val urls = response \\ "@href" map (_.toString) @@ -275,6 +275,80 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val uiRoot = "/testwebproxybase" + System.setProperty("spark.ui.proxyBase", uiRoot) + + server.stop() + + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir) + 
.set("spark.history.fs.update.interval", "0") + .set("spark.testing", "true") + + provider = new FsHistoryProvider(conf) + provider.checkForLogs() + val securityManager = new SecurityManager(conf) + + server = new HistoryServer(conf, provider, securityManager, 18080) + server.initialize() + server.bind() + + val port = server.boundPort + + val servlet = new ProxyServlet { + override def rewriteTarget(request: HttpServletRequest): String = { + // servlet acts like a proxy that redirects calls made on + // spark.ui.proxyBase context path to the normal servlet handlers operating off "/" + val sb = request.getRequestURL() + + if (request.getQueryString() != null) { + sb.append(s"?${request.getQueryString()}") + } + + val proxyidx = sb.indexOf(uiRoot) + sb.delete(proxyidx, proxyidx + uiRoot.length).toString + } + } + + val contextHandler = new ServletContextHandler + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(uiRoot) + contextHandler.addServlet(holder, "/") + server.attachHandler(contextHandler) + + implicit val webDriver: WebDriver = new HtmlUnitDriver(true) { + getWebClient.getOptions.setThrowExceptionOnScriptError(false) + } + + try { + val url = s"http://localhost:$port" + + go to s"$url$uiRoot" + + // expect the ajax call to finish in 5 seconds + implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds)) + + // once this findAll call returns, we know the ajax load of the table completed + findAll(ClassNameQuery("odd")) + + val links = findAll(TagNameQuery("a")) + .map(_.attribute("href")) + .filter(_.isDefined) + .map(_.get) + .filter(_.startsWith(url)).toList + + // there are atleast some URL links that were generated via javascript, + // and they all contain the spark.ui.proxyBase (uiRoot) + links.length should be > 4 + all(links) should startWith(url + uiRoot) + } finally { + contextHandler.stop() + quit() + } + + } + test("incomplete apps get refreshed") { implicit val webDriver: WebDriver = new HtmlUnitDriver diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef4..02df157be377c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -516,10 +516,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") /* - Check that configurable formats get configured: - ConfigTestFormat throws an exception if we try to write - to it when setConf hasn't been called first. - Assertion is in ConfigTestFormat.getRecordWriter. + * Check that configurable formats get configured: + * ConfigTestFormat throws an exception if we try to write + * to it when setConf hasn't been called first. + * Assertion is in ConfigTestFormat.getRecordWriter. 
*/ pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } @@ -544,7 +544,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val e = intercept[SparkException] { pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored") } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") @@ -725,8 +725,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } /* - These classes are fakes for testing - "saveNewAPIHadoopFile should call setConf if format is configurable". + These classes are fakes for testing saveAsHadoopFile/saveNewAPIHadoopFile. Unfortunately, they have to be top level classes, and not defined in the test method, because otherwise Scala won't generate no-args constructors and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index acdf21df9a161..aa0705987d837 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } + + test("isInRPCThread") { + val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply(rpcEnv.isInRPCThread) + } + }) + assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true) + assert(env.isInRPCThread === false) + env.stop(rpcEndpointRef) + } } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bec95d13d193a..5e8a854e46a0f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2076,7 +2076,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou } /** - * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s * denotes a shuffle dependency): diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 7f4859206e257..8a5ec37eeb66c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -202,8 +202,6 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // Make sure expected events exist in the log file. 
val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) - val logStart = SparkListenerLogStart(SPARK_VERSION) - val lines = readLines(logData) val eventSet = mutable.Set( SparkListenerApplicationStart, SparkListenerBlockManagerAdded, @@ -216,19 +214,25 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerTaskStart, SparkListenerTaskEnd, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) - lines.foreach { line => - eventSet.foreach { event => - if (line.contains(event)) { - val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) - val eventType = Utils.getFormattedClassName(parsedEvent) - if (eventType == event) { - eventSet.remove(event) + Utils.tryWithSafeFinally { + val logStart = SparkListenerLogStart(SPARK_VERSION) + val lines = readLines(logData) + lines.foreach { line => + eventSet.foreach { event => + if (line.contains(event)) { + val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) + val eventType = Utils.getFormattedClassName(parsedEvent) + if (eventType == event) { + eventSet.remove(event) + } } } } + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) + } { + logData.close() } - assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) - assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) } private def readLines(in: InputStream): Seq[String] = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 9e472f900b655..ee95e4ff7dbc3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -183,9 +183,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // ensure we reset the classloader after the test completes val originalClassLoader = Thread.currentThread.getContextClassLoader - try { + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + Utils.tryWithSafeFinally { // load the exception from the jar - val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) loader.addURL(jarFile.toURI.toURL) Thread.currentThread().setContextClassLoader(loader) val excClass: Class[_] = Utils.classForName("repro.MyException") @@ -209,8 +209,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) - } finally { + } { Thread.currentThread.setContextClassLoader(originalClassLoader) + loader.close() } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index e5d408a167361..f4786e3931c94 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -473,7 +473,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + 
"/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -486,7 +486,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST getResponseCode(url, "GET") should be (200) getResponseCode(url, "POST") should be (200) @@ -620,7 +620,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B test("live UI json application list") { withSpark(newSparkContext()) { sc => val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.appUIAddress + "/api/v1/applications")) + sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) val attempts = (appListJsonAst \ "attempts").children @@ -640,7 +640,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + val stage0 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) @@ -651,7 +651,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("{\n label="groupBy";\n " + "2 [label="MapPartitionsRDD [2]")) - val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + val stage1 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) @@ -662,7 +662,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage1.contains("{\n label="groupBy";\n " + "5 [label="MapPartitionsRDD [5]")) - val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + val stage2 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) @@ -687,7 +687,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def goToUi(ui: SparkUI, path: String): Unit = { - go to (ui.appUIAddress.stripSuffix("/") + path) + go to (ui.webUrl.stripSuffix("/") + path) } def parseDate(json: JValue): Long = { @@ -699,6 +699,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 4abcfb7e51914..68c7657cb315b 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -66,7 +66,7 @@ class UISuite extends SparkFunSuite { 
withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.get.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) assert(html.toLowerCase.contains("stages")) assert(html.toLowerCase.contains("storage")) @@ -176,19 +176,18 @@ class UISuite extends SparkFunSuite { } } - test("verify appUIAddress contains the scheme") { + test("verify webUrl contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val uiAddress = ui.appUIAddress - val uiHostPort = ui.appUIHostPort - assert(uiAddress.equals("http://" + uiHostPort)) + val uiAddress = ui.webUrl + assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://")) } } - test("verify appUIAddress contains the port") { + test("verify webUrl contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val splitUIAddress = ui.appUIAddress.split(':') + val splitUIAddress = ui.webUrl.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 8418fa74d2c63..da853f1be8b95 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -403,7 +403,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with internal = false, countFailedValues = false, metadata = None) - taskInfo.accumulables ++= Seq(internalAccum, sqlAccum, userAccum) + taskInfo.setAccumulables(List(internalAccum, sqlAccum, userAccum)) val newTaskInfo = TaskUIData.dropInternalAndSQLAccumulables(taskInfo) assert(newTaskInfo.accumulables === Seq(userAccum)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index d5146d70ebaa3..85da79180fd0b 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -788,11 +788,8 @@ private[spark] object JsonProtocolSuite extends Assertions { private def makeTaskInfo(a: Long, b: Int, c: Int, d: Long, speculative: Boolean) = { val taskInfo = new TaskInfo(a, b, c, d, "executor", "your kind sir", TaskLocality.NODE_LOCAL, speculative) - val (acc1, acc2, acc3) = - (makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true)) - taskInfo.accumulables += acc1 - taskInfo.accumulables += acc2 - taskInfo.accumulables += acc3 + taskInfo.setAccumulables( + List(makeAccumulableInfo(1), makeAccumulableInfo(2), makeAccumulableInfo(3, internal = true))) taskInfo } diff --git a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala index 8b53d4f14a6a4..f6ac89fc2742a 100644 --- a/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/MutableURLClassLoaderSuite.scala @@ -51,6 +51,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { assert(fakeClassVersion === "1") val fakeClass2 = classLoader.loadClass("FakeClass2").newInstance() assert(fakeClass.getClass === 
fakeClass2.getClass) + classLoader.close() + parentLoader.close() } test("parent first") { @@ -61,6 +63,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { assert(fakeClassVersion === "2") val fakeClass2 = classLoader.loadClass("FakeClass1").newInstance() assert(fakeClass.getClass === fakeClass2.getClass) + classLoader.close() + parentLoader.close() } test("child first can fall back") { @@ -69,6 +73,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val fakeClass = classLoader.loadClass("FakeClass3").newInstance() val fakeClassVersion = fakeClass.toString assert(fakeClassVersion === "2") + classLoader.close() + parentLoader.close() } test("child first can fail") { @@ -77,6 +83,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { intercept[java.lang.ClassNotFoundException] { classLoader.loadClass("FakeClassDoesNotExist").newInstance() } + classLoader.close() + parentLoader.close() } test("default JDK classloader get resources") { @@ -84,6 +92,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val classLoader = new URLClassLoader(fileUrlsChild, parentLoader) assert(classLoader.getResources("resource1").asScala.size === 2) assert(classLoader.getResources("resource2").asScala.size === 1) + classLoader.close() + parentLoader.close() } test("parent first get resources") { @@ -91,6 +101,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { val classLoader = new MutableURLClassLoader(fileUrlsChild, parentLoader) assert(classLoader.getResources("resource1").asScala.size === 2) assert(classLoader.getResources("resource2").asScala.size === 1) + classLoader.close() + parentLoader.close() } test("child first get resources") { @@ -103,6 +115,8 @@ class MutableURLClassLoaderSuite extends SparkFunSuite with Matchers { res1.map(scala.io.Source.fromURL(_).mkString) should contain inOrderOnly ("resource1Contents-child", "resource1Contents-parent") + classLoader.close() + parentLoader.close() } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 15ef32f21d90c..feacfb7642f27 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -264,7 +264,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val hour = minute * 60 def str: (Long) => String = Utils.msDurationToString(_) - val sep = new DecimalFormatSymbols(Locale.getDefault()).getDecimalSeparator() + val sep = new DecimalFormatSymbols(Locale.US).getDecimalSeparator assert(str(123) === "123 ms") assert(str(second) === "1" + sep + "0 s") diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 3066e9996abda..335ecb9320ab9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -49,9 +49,6 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { intercept[IllegalArgumentException] { new OpenHashMap[String, Int](-1) } - intercept[IllegalArgumentException] { - new OpenHashMap[String, String](0) - } } test("primitive value") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 
2607a543dd614..210bc5c099742 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -176,4 +176,9 @@ class OpenHashSetSuite extends SparkFunSuite with Matchers { assert(set.size === 1000) assert(set.capacity > 1000) } + + test("SPARK-18200 Support zero as an initial set size") { + val set = new OpenHashSet[Long](0) + assert(set.size === 0) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index 508e737b725bc..f5ee428020fd4 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -49,9 +49,6 @@ class PrimitiveKeyOpenHashMapSuite extends SparkFunSuite with Matchers { intercept[IllegalArgumentException] { new PrimitiveKeyOpenHashMap[Int, Int](-1) } - intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](0) - } } test("basic operations") { diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala index 366ffda7788d3..d5956ea32096a 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/RadixSortSuite.scala @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator} import scala.util.Random +import com.google.common.primitives.Ints + import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.LongArray @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter import org.apache.spark.util.random.XORShiftRandom class RadixSortSuite extends SparkFunSuite with Logging { - private val N = 10000 // scale this down for more readable results + private val N = 10000L // scale this down for more readable results /** * Describes a type of sort to test, e.g. two's complement descending. 
Each sort type has @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging { }, 2, 4, false, false, true)) - private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = { - val ref = Array.tabulate[Long](size) { i => rand } - val extended = ref ++ Array.fill[Long](size)(0) + private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0) (ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = { - val ref = Array.tabulate[Long](size * 2) { i => rand } - val extended = ref ++ Array.fill[Long](size * 2)(0) + private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = { + val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand } + val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0) (new LongArray(MemoryBlock.fromLongArray(ref)), new LongArray(MemoryBlock.fromLongArray(extended))) } - private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = { + private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = { var i = 0 - val out = new Array[Long](length) + val out = new Array[Long](Ints.checkedCast(length)) while (i < length) { out(i) = array.get(offset + i) i += 1 @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging { } } - private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) { + private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) { val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt))) new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort( - buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] { + buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] { override def compare( r1: RecordPointerAndKeyPrefix, - r2: RecordPointerAndKeyPrefix): Int = { - refCmp.compare(r1.keyPrefix, r2.keyPrefix) - } + r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix) }) } diff --git a/dev/checkstyle.xml b/dev/checkstyle.xml index 3de6aa91dcd51..92c5251c85037 100644 --- a/dev/checkstyle.xml +++ b/dev/checkstyle.xml @@ -52,6 +52,20 @@ + + + + + + + @@ -168,5 +182,6 @@ + diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 96f9b5714ebb8..1dbfa3b6e361b 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -162,14 +162,35 @@ if [[ "$1" == "package" ]]; then export ZINC_PORT=$ZINC_PORT echo "Creating distribution: $NAME ($FLAGS)" + # Write out the NAME and VERSION to PySpark version info we rewrite the - into a . and SNAPSHOT + # to dev0 to be closer to PEP440. We use the NAME as a "local version". 
+ PYSPARK_VERSION=`echo "$SPARK_VERSION+$NAME" | sed -r "s/-/./" | sed -r "s/SNAPSHOT/dev0/"` + echo "__version__='$PYSPARK_VERSION'" > python/pyspark/version.py + # Get maven home set by MVN MVN_HOME=`$MVN -version 2>&1 | grep 'Maven home' | awk '{print $NF}'` - ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz $FLAGS \ + echo "Creating distribution" + ./dev/make-distribution.sh --name $NAME --mvn $MVN_HOME/bin/mvn --tgz --pip $FLAGS \ -DzincPort=$ZINC_PORT 2>&1 > ../binary-release-$NAME.log cd .. - cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . + echo "Copying and signing python distribution" + PYTHON_DIST_NAME=pyspark-$PYSPARK_VERSION.tar.gz + cp spark-$SPARK_VERSION-bin-$NAME/python/dist/$PYTHON_DIST_NAME . + + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ + --output $PYTHON_DIST_NAME.asc \ + --detach-sig $PYTHON_DIST_NAME + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + MD5 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.md5 + echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --print-md \ + SHA512 $PYTHON_DIST_NAME > \ + $PYTHON_DIST_NAME.sha + + echo "Copying and signing regular binary distribution" + cp spark-$SPARK_VERSION-bin-$NAME/spark-$SPARK_VERSION-bin-$NAME.tgz . echo $GPG_PASSPHRASE | $GPG --passphrase-fd 0 --armour \ --output spark-$SPARK_VERSION-bin-$NAME.tgz.asc \ --detach-sig spark-$SPARK_VERSION-bin-$NAME.tgz @@ -187,10 +208,10 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos" - make_binary_release "hadoop2.3" "-Phadoop2.3 $FLAGS" "3033" & - make_binary_release "hadoop2.4" "-Phadoop2.4 $FLAGS" "3034" & - make_binary_release "hadoop2.6" "-Phadoop2.6 $FLAGS" "3035" & - make_binary_release "hadoop2.7" "-Phadoop2.7 $FLAGS" "3036" & + make_binary_release "hadoop2.3" "-Phadoop-2.3 $FLAGS" "3033" & + make_binary_release "hadoop2.4" "-Phadoop-2.4 $FLAGS" "3034" & + make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" & + make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" & make_binary_release "hadoop2.4-without-hive" "-Psparkr -Phadoop-2.4 -Pyarn -Pmesos" "3037" & make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" & wait @@ -208,6 +229,7 @@ if [[ "$1" == "package" ]]; then # Re-upload a second time and leave the files in the timestamped upload directory: LFTP mkdir -p $dest_dir LFTP mput -O $dest_dir 'spark-*' + LFTP mput -O $dest_dir 'pyspark-*' exit 0 fi diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b7e5100ca7408..370a62ce15bc4 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -65,6 +65,7 @@ sed -i".tmp1" 's/Version.*$/Version: '"$RELEASE_VERSION"'/g' R/pkg/DESCRIPTION # Set the release version in docs sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml +sed -i".tmp3" 's/__version__ = .*$/__version__ = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" @@ -74,12 +75,16 @@ git tag $RELEASE_TAG $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs # Remove -SNAPSHOT before setting the R version as R expects version strings to only have 
numbers R_NEXT_VERSION=`echo $NEXT_VERSION | sed 's/-SNAPSHOT//g'` -sed -i".tmp2" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +sed -i".tmp4" 's/Version.*$/Version: '"$R_NEXT_VERSION"'/g' R/pkg/DESCRIPTION +# Write out the R_NEXT_VERSION to PySpark version info we use dev0 instead of SNAPSHOT to be closer +# to PEP440. +sed -i".tmp5" 's/__version__ = .*$/__version__ = "'"$R_NEXT_VERSION.dev0"'"/' python/pyspark/version.py + # Update docs with next version -sed -i".tmp3" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp6" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$NEXT_VERSION"'/g' docs/_config.yml # Use R version for short version -sed -i".tmp4" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml +sed -i".tmp7" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$R_NEXT_VERSION"'/g' docs/_config.yml git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index dfe4eb9f8bc65..b5695aa291920 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -117,7 +117,6 @@ jline-2.12.1.jar joda-time-2.9.3.jar jodd-core-3.5.2.jar jpam-1.1.jar -json-20090211.jar json4s-ast_2.11-3.2.11.jar json4s-core_2.11-3.2.11.jar json4s-jackson_2.11-3.2.11.jar @@ -141,7 +140,7 @@ metrics-jvm-3.1.2.jar minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar -netty-all-4.0.41.Final.jar +netty-all-4.0.42.Final.jar objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar diff --git a/dev/lint-python b/dev/lint-python index 63487043a50b6..3f878c2dad6b1 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,7 +20,9 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" -PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" +# TODO: fix pep8 errors with the rest of the Python scripts under dev +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/*.py ./dev/run-tests-jenkins.py" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/pip-sanity-check.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" PYLINT_INSTALL_INFO="$SPARK_ROOT_DIR/dev/pylint-info.txt" diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index 00b1ba201b145..b946397aff571 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -33,6 +33,7 @@ SPARK_HOME="$(cd "`dirname "$0"`/.."; pwd)" DISTDIR="$SPARK_HOME/dist" MAKE_TGZ=false +MAKE_PIP=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -40,7 +41,7 @@ function exit_with_usage { echo "make-distribution.sh - tool for making binary distributions of Spark" echo "" echo "usage:" - cl_options="[--name] [--tgz] [--mvn ]" + cl_options="[--name] [--tgz] [--pip] [--mvn ]" echo "make-distribution.sh $cl_options " echo "See Spark's \"Building Spark\" doc for correct Maven options." echo "" @@ -67,6 +68,9 @@ while (( "$#" )); do --tgz) MAKE_TGZ=true ;; + --pip) + MAKE_PIP=true + ;; --mvn) MVN="$2" shift @@ -201,6 +205,16 @@ fi # Copy data files cp -r "$SPARK_HOME/data" "$DISTDIR" +# Make pip package +if [ "$MAKE_PIP" == "true" ]; then + echo "Building python distribution package" + cd $SPARK_HOME/python + python setup.py sdist + cd .. 
+else + echo "Skipping creating pip installable PySpark" +fi + # Copy other things mkdir "$DISTDIR"/conf cp "$SPARK_HOME"/conf/*.template "$DISTDIR"/conf diff --git a/dev/pip-sanity-check.py b/dev/pip-sanity-check.py new file mode 100644 index 0000000000000..430c2ab52766a --- /dev/null +++ b/dev/pip-sanity-check.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark.sql import SparkSession +import sys + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("PipSanityCheck")\ + .getOrCreate() + sc = spark.sparkContext + rdd = sc.parallelize(range(100), 10) + value = rdd.reduce(lambda x, y: x + y) + if (value != 4950): + print("Value {0} did not match expected value.".format(value), file=sys.stderr) + sys.exit(-1) + print("Successfully ran pip sanity check") + + spark.stop() diff --git a/dev/run-pip-tests b/dev/run-pip-tests new file mode 100755 index 0000000000000..e1da18e60bb3d --- /dev/null +++ b/dev/run-pip-tests @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stop on error +set -e +# Set nullglob for when we are checking existence based on globs +shopt -s nullglob + +FWDIR="$(cd "$(dirname "$0")"/..; pwd)" +cd "$FWDIR" + +echo "Constucting virtual env for testing" +VIRTUALENV_BASE=$(mktemp -d) + +# Clean up the virtual env enviroment used if we created one. +function delete_virtualenv() { + echo "Cleaning up temporary directory - $VIRTUALENV_BASE" + rm -rf "$VIRTUALENV_BASE" +} +trap delete_virtualenv EXIT + +# Some systems don't have pip or virtualenv - in those cases our tests won't work. +if ! hash virtualenv 2>/dev/null; then + echo "Missing virtualenv skipping pip installability tests." + exit 0 +fi +if ! hash pip 2>/dev/null; then + echo "Missing pip, skipping pip installability tests." 
+ exit 0 +fi + +# Figure out which Python execs we should test pip installation with +PYTHON_EXECS=() +if hash python2 2>/dev/null; then + # We do this since we are testing with virtualenv and the default virtual env python + # is in /usr/bin/python + PYTHON_EXECS+=('python2') +elif hash python 2>/dev/null; then + # If python2 isn't installed fallback to python if available + PYTHON_EXECS+=('python') +fi +if hash python3 2>/dev/null; then + PYTHON_EXECS+=('python3') +fi + +# Determine which version of PySpark we are building for archive name +PYSPARK_VERSION=$(python -c "exec(open('python/pyspark/version.py').read());print __version__") +PYSPARK_DIST="$FWDIR/python/dist/pyspark-$PYSPARK_VERSION.tar.gz" +# The pip install options we use for all the pip commands +PIP_OPTIONS="--upgrade --no-cache-dir --force-reinstall " +# Test both regular user and edit/dev install modes. +PIP_COMMANDS=("pip install $PIP_OPTIONS $PYSPARK_DIST" + "pip install $PIP_OPTIONS -e python/") + +for python in "${PYTHON_EXECS[@]}"; do + for install_command in "${PIP_COMMANDS[@]}"; do + echo "Testing pip installation with python $python" + # Create a temp directory for us to work in and save its name to a file for cleanup + echo "Using $VIRTUALENV_BASE for virtualenv" + VIRTUALENV_PATH="$VIRTUALENV_BASE"/$python + rm -rf "$VIRTUALENV_PATH" + mkdir -p "$VIRTUALENV_PATH" + virtualenv --python=$python "$VIRTUALENV_PATH" + source "$VIRTUALENV_PATH"/bin/activate + # Upgrade pip + pip install --upgrade pip + + echo "Creating pip installable source dist" + cd "$FWDIR"/python + $python setup.py sdist + + + echo "Installing dist into virtual env" + cd dist + # Verify that the dist directory only contains one thing to install + sdists=(*.tar.gz) + if [ ${#sdists[@]} -ne 1 ]; then + echo "Unexpected number of targets found in dist directory - please cleanup existing sdists first." 
+ exit -1 + fi + # Do the actual installation + cd "$FWDIR" + $install_command + + cd / + + echo "Run basic sanity check on pip installed version with spark-submit" + spark-submit "$FWDIR"/dev/pip-sanity-check.py + echo "Run basic sanity check with import based" + python "$FWDIR"/dev/pip-sanity-check.py + echo "Run the tests for context.py" + python "$FWDIR"/python/pyspark/context.py + + cd "$FWDIR" + + done +done + +exit 0 diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index a48d918f9dc1f..1d1e72faccf2a 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -128,6 +128,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', + ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests', ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests', ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of \`%s\`' % ( tests_timeout) diff --git a/dev/run-tests.py b/dev/run-tests.py index 5d661f5f1a1c5..ab285ac96af7e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -432,6 +432,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_python_packaging_tests(): + set_title_and_block("Running PySpark packaging tests", "BLOCK_PYSPARK_PIP_TESTS") + command = [os.path.join(SPARK_HOME, "dev", "run-pip-tests")] + run_cmd(command) + + def run_build_tests(): set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) @@ -583,6 +589,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: run_python_tests(modules_with_python_tests, opts.parallelism) + run_python_packaging_tests() if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 89015f8c4fb9c..38f25da41f775 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -33,5 +33,6 @@ "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, "BLOCK_BUILD_TESTS": 22, + "BLOCK_PYSPARK_PIP_TESTS": 23, "BLOCK_TIMEOUT": 124 } diff --git a/docs/building-spark.md b/docs/building-spark.md index ebe46a42a15c6..88da0cc9c3bbf 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -13,6 +13,7 @@ redirect_from: "building-with-maven.html" The Maven-based build is the build of reference for Apache Spark. Building Spark using Maven requires Maven 3.3.9 or newer and Java 7+. +Note that support for Java 7 is deprecated as of Spark 2.0.0 and may be removed in Spark 2.2.0. ### Setting up Maven's Memory Usage @@ -79,6 +80,9 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro +Note that support for versions of Hadoop before 2.6 are deprecated as of Spark 2.1.0 and may be +removed in Spark 2.2.0. + You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. @@ -129,6 +133,8 @@ To produce a Spark package compiled with Scala 2.10, use the `-Dscala-2.10` prop ./dev/change-scala-version.sh 2.10 ./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package + +Note that support for Scala 2.10 is deprecated as of Spark 2.1.0 and may be removed in Spark 2.2.0. 
## Building submodules individually @@ -259,6 +265,14 @@ or Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. +## PySpark pip installable + +If you are building Spark for use in a Python environment and you wish to pip install it, you will first need to build the Spark JARs as described above. Then you can use setup.py to construct an sdist package that can be installed with pip. + + cd python; python setup.py sdist + +**Note:** Due to packaging requirements you cannot pip install directly from the Python directory; rather, you must first build the sdist package as described above. + ## PySpark Tests with Maven If you are building PySpark and wish to run the PySpark tests you will need to build Spark with Hive support. diff --git a/docs/configuration.md b/docs/configuration.md index 780fc94908d38..a3b4ff01e6d92 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -202,6 +202,15 @@ of the most common options to set are: or remotely ("cluster") on one of the nodes inside the cluster. + + spark.log.callerContext + (none) + + Application information that will be written into Yarn RM log/HDFS audit log when running on Yarn/HDFS. + Its length is limited by the Hadoop configuration hadoop.caller.context.max.size. It should be concise, + and typically can have up to 50 characters. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -767,7 +776,7 @@ Apart from these, the following properties are also available, and may be useful spark.kryo.referenceTracking - true (false when using Spark SQL Thrift Server) + true Whether to track references to the same object when serializing data with Kryo, which is necessary if your object graphs have loops and useful for efficiency if they contain multiple @@ -838,8 +847,7 @@ Apart from these, the following properties are also available, and may be useful spark.serializer - org.apache.spark.serializer.
JavaSerializer (org.apache.spark.serializer.
- KryoSerializer when using Spark SQL Thrift Server) + org.apache.spark.serializer.
JavaSerializer Class to use for serializing objects that will be sent over the network or need to be cached @@ -1035,6 +1043,22 @@ Apart from these, the following properties are also available, and may be useful its contents do not match those of the source. + + spark.files.maxPartitionBytes + 134217728 (128 MB) + + The maximum number of bytes to pack into a single partition when reading files. + + + + spark.files.openCostInBytes + 4194304 (4 MB) + + The estimated cost to open a file, measured by the number of bytes that could be scanned in the same + time. This is used when putting multiple files into a partition. It is better to overestimate; + then the partitions with small files will be faster than partitions with bigger files. + + spark.hadoop.cloneConf false @@ -1160,7 +1184,7 @@ Apart from these, the following properties are also available, and may be useful spark.rpc.askTimeout - 120s + spark.network.timeout Duration for an RPC ask operation to wait before timing out. @@ -1514,9 +1538,35 @@ Apart from these, the following properties are also available, and may be useful currently supported by the external shuffle service. + + spark.authenticate.encryption.aes.enabled + false + + Enable AES for over-the-wire encryption. + + + + spark.authenticate.encryption.aes.cipher.keySize + 16 + + The key size of the AES cipher in bytes, effective when AES encryption is enabled. AES + works with 16, 24 and 32 byte keys. + + + + spark.authenticate.encryption.aes.cipher.class + null + + Specify the underlying implementation class of the crypto cipher. Set to null to use the default. + In order to use OpenSslCipher, users should install OpenSSL. Currently, there are two cipher + classes available in the Commons Crypto library: + org.apache.commons.crypto.cipher.OpenSslCipher + org.apache.commons.crypto.cipher.JceCipher + + spark.core.connection.ack.wait.timeout - 60s + spark.network.timeout How long for the connection to wait for ack to occur before timing out and giving up. To avoid unwilling timeout caused by long pause like GC, @@ -1901,7 +1951,7 @@ showDF(properties, numRows = 200, truncate = FALSE) spark.r.heartBeatInterval 100 - Interval for heartbeats sents from SparkR backend to R process to prevent connection timeout. + Interval for heartbeats sent from SparkR backend to R process to prevent connection timeout.
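As a quick illustration of how the file-reader settings documented above are supplied in practice, here is a minimal PySpark sketch; the application name is hypothetical and the values simply restate the documented defaults:

```python
from pyspark.sql import SparkSession

# Minimal sketch: supply the file-reader tuning properties at session startup.
# The values below just restate the documented defaults (128 MB per read
# partition, 4 MB estimated cost to open a file).
spark = (SparkSession.builder
         .appName("file-reader-tuning-sketch")           # hypothetical app name
         .config("spark.files.maxPartitionBytes", "134217728")
         .config("spark.files.openCostInBytes", "4194304")
         .getOrCreate())
```

The same keys can also be set in `spark-defaults.conf` or passed with `--conf` on `spark-submit`, like any other Spark property.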
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 58671e6f146d8..e271b28fb4f28 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -11,6 +11,7 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[VertexRDD]: api/scala/index.html#org.apache.spark.graphx.VertexRDD [Edge]: api/scala/index.html#org.apache.spark.graphx.Edge [EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet [Graph]: api/scala/index.html#org.apache.spark.graphx.Graph @@ -35,7 +36,6 @@ description: GraphX graph processing library guide for Spark SPARK_VERSION_SHORT [Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] [Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] [PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] [PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ [ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ [TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ @@ -89,7 +89,7 @@ with user defined objects attached to each vertex and edge. A directed multigra graph with potentially multiple parallel edges sharing the same source and destination vertex. The ability to support parallel edges simplifies modeling scenarios where there can be multiple relationships (e.g., co-worker and friend) between the same vertices. Each vertex is keyed by a -*unique* 64-bit long identifier (`VertexID`). GraphX does not impose any ordering constraints on +*unique* 64-bit long identifier (`VertexId`). GraphX does not impose any ordering constraints on the vertex identifiers. Similarly, edges have corresponding source and destination vertex identifiers. @@ -130,12 +130,12 @@ class Graph[VD, ED] { } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexId, VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the -`VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge +`VertexRDD`[VertexRDD] and `EdgeRDD`[EdgeRDD] API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: -`RDD[(VertexID, VD)]` and `RDD[Edge[ED]]`. +`RDD[(VertexId, VD)]` and `RDD[Edge[ED]]`. ### Example Property Graph @@ -197,7 +197,7 @@ graph.edges.filter(e => e.srcId > e.dstId).count {% endhighlight %} > Note that `graph.vertices` returns an `VertexRDD[(String, String)]` which extends -> `RDD[(VertexID, (String, String))]` and so we use the scala `case` expression to deconstruct the +> `RDD[(VertexId, (String, String))]` and so we use the scala `case` expression to deconstruct the > tuple. 
On the other hand, `graph.edges` returns an `EdgeRDD` containing `Edge[String]` objects. > We could have also used the case class type constructor as in the following: > {% highlight scala %} @@ -287,7 +287,7 @@ class Graph[VD, ED] { // Change the partitioning heuristic ============================================================ def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== - def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -297,18 +297,18 @@ class Graph[VD, ED] { def reverse: Graph[VD, ED] def subgraph( epred: EdgeTriplet[VD,ED] => Boolean = (x => true), - vpred: (VertexID, VD) => Boolean = ((v, d) => true)) + vpred: (VertexId, VD) => Boolean = ((v, d) => true)) : Graph[VD, ED] def mask[VD2, ED2](other: Graph[VD2, ED2]): Graph[VD, ED] def groupEdges(merge: (ED, ED) => ED): Graph[VD, ED] // Join RDDs with the graph ====================================================================== - def joinVertices[U](table: RDD[(VertexID, U)])(mapFunc: (VertexID, VD, U) => VD): Graph[VD, ED] - def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) - (mapFunc: (VertexID, VD, Option[U]) => VD2) + def joinVertices[U](table: RDD[(VertexId, U)])(mapFunc: (VertexId, VD, U) => VD): Graph[VD, ED] + def outerJoinVertices[U, VD2](other: RDD[(VertexId, U)]) + (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] // Aggregate information about adjacent triplets ================================================= - def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] - def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] + def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] + def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] def aggregateMessages[Msg: ClassTag]( sendMsg: EdgeContext[VD, ED, Msg] => Unit, mergeMsg: (Msg, Msg) => Msg, @@ -316,15 +316,15 @@ class Graph[VD, ED] { : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( - vprog: (VertexID, VD, A) => VD, - sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexID,A)], + vprog: (VertexId, VD, A) => VD, + sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId,A)], mergeMsg: (A, A) => A) : Graph[VD, ED] // Basic graph algorithms ======================================================================== def pageRank(tol: Double, resetProb: Double = 0.15): Graph[Double, Double] - def connectedComponents(): Graph[VertexID, ED] + def connectedComponents(): Graph[VertexId, ED] def triangleCount(): Graph[Int, ED] - def stronglyConnectedComponents(numIter: Int): Graph[VertexID, ED] + def stronglyConnectedComponents(numIter: Int): Graph[VertexId, ED] } {% endhighlight %} @@ -481,7 +481,7 @@ original value. > is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. 
> {% highlight scala %} -val nonUniqueCosts: RDD[(VertexID, Double)] +val nonUniqueCosts: RDD[(VertexId, Double)] val uniqueCosts: VertexRDD[Double] = graph.vertices.aggregateUsingIndex(nonUnique, (a,b) => a + b) val joinedGraph = graph.joinVertices(uniqueCosts)( @@ -511,7 +511,7 @@ val degreeGraph = graph.outerJoinVertices(outDegrees) { (id, oldAttr, outDegOpt) > provide type annotation for the user defined function: > {% highlight scala %} val joinedGraph = graph.joinVertices(uniqueCosts, - (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) + (id: VertexId, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} > @@ -558,7 +558,7 @@ The user defined `mergeMsg` function takes two messages destined to the same ver yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not -receive a message are not included in the returned `VertexRDD`. +receive a message are not included in the returned `VertexRDD`[VertexRDD]. + +More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.LinearRegression). + {% include_example python/ml/linear_regression_with_elastic_net.py %}
@@ -511,7 +538,7 @@ others. -**Example** +**Examples** The following example demonstrates training a GLM with a Gaussian response and identity link function and extracting model summary statistics. @@ -519,18 +546,21 @@ function and extracting model summary statistics.
+ Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example scala/org/apache/spark/examples/ml/GeneralizedLinearRegressionExample.scala %}
+ Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/GeneralizedLinearRegression.html) for more details. {% include_example java/org/apache/spark/examples/ml/JavaGeneralizedLinearRegressionExample.java %}
+ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.GeneralizedLinearRegression) for more details. {% include_example python/ml/generalized_linear_regression_example.py %} @@ -544,7 +574,7 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. Decision trees are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on decision trees](#decision-trees). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the Decision Tree algorithm can recognize. @@ -579,7 +609,7 @@ More details on parameters can be found in the [Python API documentation](api/py Random forests are a popular family of classification and regression methods. More information about the `spark.ml` implementation can be found further in the [section on random forests](#random-forests). -**Example** +**Examples** The following examples load a dataset in LibSVM format, split it into training and test sets, train on the first dataset, and then evaluate on the held-out test set. We use a feature transformer to index categorical features, adding metadata to the `DataFrame` which the tree-based algorithms can recognize. @@ -612,7 +642,7 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression. Gradient-boosted trees (GBTs) are a popular regression method using ensembles of decision trees. More information about the `spark.ml` implementation can be found further in the [section on GBTs](#gradient-boosted-trees-gbts). -**Example** +**Examples** Note: For this example dataset, `GBTRegressor` actually only needs 1 iteration, but that will not be true in general. @@ -700,19 +730,28 @@ The implementation matches the result from R's survival function > When fitting AFTSurvivalRegressionModel without intercept on dataset with constant nonzero column, Spark MLlib outputs zero coefficients for constant nonzero columns. This behavior is different from R survival::survreg. -**Example** +**Examples**
+ +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %}
+ +Refer to the [Java API docs](api/java/org/apache/spark/ml/regression/AFTSurvivalRegression.html) for more details. + {% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %}
+ +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.regression.AFTSurvivalRegression) for more details. + {% include_example python/ml/aft_survival_regression.py %}
@@ -765,7 +804,7 @@ is treated as piecewise linear function. The rules for prediction therefore are: predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. -### Examples +**Examples**
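To make the prediction rules above concrete, here is a small pure-Python sketch of the interpolation step; it is an illustration rather than Spark's implementation, and `boundaries`/`predictions` are hypothetical sorted arrays of the model's known feature values and their fitted values:

```python
from bisect import bisect_right

def isotonic_predict(boundaries, predictions, feature):
    # Outside the training range: use the prediction at the closest boundary.
    if feature <= boundaries[0]:
        return predictions[0]
    if feature >= boundaries[-1]:
        return predictions[-1]
    # Otherwise interpolate linearly between the two closest known features;
    # an exact match falls out of the interpolation with weight 0.
    i = bisect_right(boundaries, feature)
    lo, hi = boundaries[i - 1], boundaries[i]
    weight = (feature - lo) / (hi - lo)
    return predictions[i - 1] * (1 - weight) + predictions[i] * weight

# Known features [1.0, 3.0, 5.0] with fitted predictions [2.0, 2.0, 6.0]:
print(isotonic_predict([1.0, 3.0, 5.0], [2.0, 2.0, 6.0], 4.0))  # 4.0
```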
@@ -945,7 +984,7 @@ Random forests combine many decision trees in order to reduce the risk of overfi The `spark.ml` implementation supports random forests for binary and multiclass classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on random forests](mllib-ensembles.html#random-forests). ### Inputs and Outputs @@ -1026,7 +1065,7 @@ GBTs iteratively train decision trees in order to minimize a loss function. The `spark.ml` implementation supports GBTs for binary classification and for regression, using both continuous and categorical features. -For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html). +For more information on the algorithm itself, please see the [`spark.mllib` documentation on GBTs](mllib-ensembles.html#gradient-boosted-trees-gbts). ### Inputs and Outputs diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index 8a0a61cb595e7..eedacb12bc46f 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -65,7 +65,7 @@ called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). -### Example +**Examples**
@@ -94,6 +94,8 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering. and generates a `LDAModel` as the base model. Expert users may cast a `LDAModel` generated by `EMLDAOptimizer` to a `DistributedLDAModel` if needed. +**Examples** +
@@ -128,7 +130,7 @@ Bisecting K-means can often be much faster than regular K-means, but it will gen `BisectingKMeans` is implemented as an `Estimator` and generates a `BisectingKMeansModel` as the base model. -### Example +**Examples**
@@ -210,7 +212,7 @@ model. -### Example +**Examples**
diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index 1d02d6933cb48..4d19b4069a1f2 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -59,7 +59,7 @@ This approach is named "ALS-WR" and discussed in the paper It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. -## Examples +**Examples**
diff --git a/docs/ml-features.md b/docs/ml-features.md index 64c6a160239cc..d2f036fb083da 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -112,6 +112,8 @@ can then be used as features for prediction, document similarity calculations, e Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details. +**Examples** + In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
@@ -220,6 +222,8 @@ for more details on the API. Alternatively, users can set parameter "gaps" to false indicating the regex "pattern" denotes "tokens" rather than splitting gaps, and find all matching occurrences as the tokenization result. +**Examples** +
@@ -321,6 +325,8 @@ An [n-gram](https://en.wikipedia.org/wiki/N-gram) is a sequence of $n$ tokens (t `NGram` takes as input a sequence of strings (e.g. the output of a [Tokenizer](ml-features.html#tokenizer)). The parameter `n` is used to determine the number of terms in each $n$-gram. The output will consist of a sequence of $n$-grams where each $n$-gram is represented by a space-delimited string of $n$ consecutive words. If the input sequence contains fewer than `n` strings, no output is produced. +**Examples** +
@@ -358,6 +364,8 @@ for binarization. Feature values greater than the threshold are binarized to 1.0 to or less than the threshold are binarized to 0.0. Both Vector and Double types are supported for `inputCol`. +**Examples** +
@@ -388,6 +396,8 @@ for more details on the API. [PCA](http://en.wikipedia.org/wiki/Principal_component_analysis) is a statistical procedure that uses an orthogonal transformation to convert a set of observations of possibly correlated variables into a set of values of linearly uncorrelated variables called principal components. A [PCA](api/scala/index.html#org.apache.spark.ml.feature.PCA) class trains a model to project vectors to a low-dimensional space using PCA. The example below shows how to project 5-dimensional feature vectors into 3-dimensional principal components. +**Examples** +
@@ -418,6 +428,8 @@ for more details on the API. [Polynomial expansion](http://en.wikipedia.org/wiki/Polynomial_expansion) is the process of expanding your features into a polynomial space, which is formulated by an n-degree combination of original dimensions. A [PolynomialExpansion](api/scala/index.html#org.apache.spark.ml.feature.PolynomialExpansion) class provides this functionality. The example below shows how to expand your features into a 3-degree polynomial space. +**Examples** +
@@ -458,6 +470,8 @@ for the transform is unitary. No shift is applied to the transformed sequence (e.g. the $0$th element of the transformed sequence is the $0$th DCT coefficient and _not_ the $N/2$th). +**Examples** +
@@ -663,6 +677,8 @@ for more details on the API. [One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features. +**Examples** +
@@ -694,13 +710,15 @@ for more details on the API. `VectorIndexer` helps index categorical features in datasets of `Vector`s. It can both automatically decide which features are categorical and convert original values to category indices. Specifically, it does the following: -1. Take an input column of type [Vector](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) and a parameter `maxCategories`. +1. Take an input column of type [Vector](api/scala/index.html#org.apache.spark.ml.linalg.Vector) and a parameter `maxCategories`. 2. Decide which features should be categorical based on the number of distinct values, where features with at most `maxCategories` are declared categorical. 3. Compute 0-based category indices for each categorical feature. 4. Index categorical features and transform original feature values to indices. Indexing categorical features allows algorithms such as Decision Trees and Tree Ensembles to treat categorical features appropriately, improving performance. +**Examples** + In the example below, we read in a dataset of labeled points and then use `VectorIndexer` to decide which features should be treated as categorical. We transform the categorical feature values to their indices. This transformed data could then be passed to algorithms such as `DecisionTreeRegressor` that handle categorical features.
@@ -729,11 +747,65 @@ for more details on the API.
+## Interaction + +`Interaction` is a `Transformer` which takes vector or double-valued columns, and generates a single vector column that contains the product of all combinations of one value from each input column. + +For example, if you have 2 vector-type input columns, each with 3 dimensions, then you'll get a 9-dimensional vector as the output column. + +**Examples** + +Assume that we have the following DataFrame with the columns "id1", "vec1", and "vec2": + +~~~~ + id1|vec1 |vec2 + ---|--------------|-------------- + 1 |[1.0,2.0,3.0] |[8.0,4.0,5.0] + 2 |[4.0,3.0,8.0] |[7.0,9.0,8.0] + 3 |[6.0,1.0,9.0] |[2.0,3.0,6.0] + 4 |[10.0,8.0,6.0]|[9.0,4.0,5.0] + 5 |[9.0,2.0,7.0] |[10.0,7.0,3.0] + 6 |[1.0,1.0,4.0] |[2.0,8.0,4.0] +~~~~ + +Applying `Interaction` to those input columns, +the output column `interactedCol` contains: + +~~~~ + id1|vec1 |vec2 |interactedCol + ---|--------------|--------------|------------------------------------------------------ + 1 |[1.0,2.0,3.0] |[8.0,4.0,5.0] |[8.0,4.0,5.0,16.0,8.0,10.0,24.0,12.0,15.0] + 2 |[4.0,3.0,8.0] |[7.0,9.0,8.0] |[56.0,72.0,64.0,42.0,54.0,48.0,112.0,144.0,128.0] + 3 |[6.0,1.0,9.0] |[2.0,3.0,6.0] |[36.0,54.0,108.0,6.0,9.0,18.0,54.0,81.0,162.0] + 4 |[10.0,8.0,6.0]|[9.0,4.0,5.0] |[360.0,160.0,200.0,288.0,128.0,160.0,216.0,96.0,120.0] + 5 |[9.0,2.0,7.0] |[10.0,7.0,3.0]|[450.0,315.0,135.0,100.0,70.0,30.0,350.0,245.0,105.0] + 6 |[1.0,1.0,4.0] |[2.0,8.0,4.0] |[12.0,48.0,24.0,12.0,48.0,24.0,48.0,192.0,96.0] +~~~~ + +
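To make the ordering of the products explicit, here is a tiny pure-Python sketch (it does not use the Spark API) that reproduces the second row of the table above; the listed values are consistent with all three columns, including the scalar "id1", being used as inputs, with the scalar treated as a one-element vector:

```python
from itertools import product

# Row 2 above: id1 = 2, vec1 = [4.0, 3.0, 8.0], vec2 = [7.0, 9.0, 8.0].
columns = [[2.0], [4.0, 3.0, 8.0], [7.0, 9.0, 8.0]]

# One product per combination of one value from each input column.
interacted = [a * b * c for a, b, c in product(*columns)]
print(interacted)
# [56.0, 72.0, 64.0, 42.0, 54.0, 48.0, 112.0, 144.0, 128.0]
```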
+
+ +Refer to the [Interaction Scala docs](api/scala/index.html#org.apache.spark.ml.feature.Interaction) +for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/InteractionExample.scala %} +
+ +
+ +Refer to the [Interaction Java docs](api/java/org/apache/spark/ml/feature/Interaction.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaInteractionExample.java %} +
+
## Normalizer `Normalizer` is a `Transformer` which transforms a dataset of `Vector` rows, normalizing each `Vector` to have unit norm. It takes parameter `p`, which specifies the [p-norm](http://en.wikipedia.org/wiki/Norm_%28mathematics%29#p-norm) used for normalization. ($p = 2$ by default.) This normalization can help standardize your input data and improve the behavior of learning algorithms. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^1$ norm and unit $L^\infty$ norm.
@@ -774,6 +846,8 @@ for more details on the API. Note that if the standard deviation of a feature is zero, it will return default `0.0` value in the `Vector` for that feature. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
@@ -819,6 +893,8 @@ For the case `$E_{max} == E_{min}$`, `$Rescaled(e_i) = 0.5 * (max + min)$` Note that since zero values will probably be transformed to non-zero values, output of the transformer will be `DenseVector` even for sparse input. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [0, 1].
@@ -860,6 +936,8 @@ data, and thus does not destroy any sparsity. `MaxAbsScaler` computes summary statistics on a data set and produces a `MaxAbsScalerModel`. The model can then transform each feature individually to range [-1, 1]. +**Examples** + The following example demonstrates how to load a dataset in libsvm format and then rescale each feature to [-1, 1].
@@ -903,6 +981,8 @@ Note also that the splits that you provided have to be in strictly increasing or More details can be found in the API docs for [Bucketizer](api/scala/index.html#org.apache.spark.ml.feature.Bucketizer). +**Examples** + The following example demonstrates how to bucketize a column of `Double`s into another index-wised column.
@@ -951,6 +1031,8 @@ v_N \end{pmatrix} \]` +**Examples** + This example below demonstrates how to transform vectors using a transforming vector value.
@@ -1338,14 +1420,14 @@ for more details on the API. `ChiSqSelector` stands for Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. **Examples** diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index adb057ba7e250..0384513ab7014 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -38,26 +38,26 @@ algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Pipelines API, where the pipeline concept is mostly inspired by the [scikit-learn](http://scikit-learn.org/) project. -* **[`DataFrame`](ml-guide.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML +* **[`DataFrame`](ml-pipeline.html#dataframe)**: This ML API uses `DataFrame` from Spark SQL as an ML dataset, which can hold a variety of data types. E.g., a `DataFrame` could have different columns storing text, feature vectors, true labels, and predictions. -* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. +* **[`Transformer`](ml-pipeline.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`. E.g., an ML model is a `Transformer` which transforms a `DataFrame` with features into a `DataFrame` with predictions. -* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. +* **[`Estimator`](ml-pipeline.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`. E.g., a learning algorithm is an `Estimator` which trains on a `DataFrame` and produces a model. -* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. +* **[`Pipeline`](ml-pipeline.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow. -* **[`Parameter`](ml-guide.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. 
+* **[`Parameter`](ml-pipeline.html#parameters)**: All `Transformer`s and `Estimator`s now share a common API for specifying parameters. ## DataFrame Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data. This API adopts the `DataFrame` from Spark SQL in order to support a variety of data types. -`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types. +`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#data-types) for a list of supported types. In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](mllib-data-types.html#local-vector) types. A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples. @@ -207,14 +207,29 @@ This example covers the concepts of `Estimator`, `Transformer`, and `Param`.
+ +Refer to the [`Estimator` Scala docs](api/scala/index.html#org.apache.spark.ml.Estimator), +the [`Transformer` Scala docs](api/scala/index.html#org.apache.spark.ml.Transformer) and +the [`Params` Scala docs](api/scala/index.html#org.apache.spark.ml.param.Params) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala %}
+ +Refer to the [`Estimator` Java docs](api/java/org/apache/spark/ml/Estimator.html), +the [`Transformer` Java docs](api/java/org/apache/spark/ml/Transformer.html) and +the [`Params` Java docs](api/java/org/apache/spark/ml/param/Params.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaEstimatorTransformerParamExample.java %}
+ +Refer to the [`Estimator` Python docs](api/python/pyspark.ml.html#pyspark.ml.Estimator), +the [`Transformer` Python docs](api/python/pyspark.ml.html#pyspark.ml.Transformer) and +the [`Params` Python docs](api/python/pyspark.ml.html#pyspark.ml.param.Params) for more details on the API. + {% include_example python/ml/estimator_transformer_param_example.py %}
@@ -227,14 +242,24 @@ This example follows the simple text document `Pipeline` illustrated in the figu
+ +Refer to the [`Pipeline` Scala docs](api/scala/index.html#org.apache.spark.ml.Pipeline) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/PipelineExample.scala %}
+ + +Refer to the [`Pipeline` Java docs](api/java/org/apache/spark/ml/Pipeline.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaPipelineExample.java %}
+ +Refer to the [`Pipeline` Python docs](api/python/pyspark.ml.html#pyspark.ml.Pipeline) for more details on the API. + {% include_example python/ml/pipeline_example.py %}
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md index 2ca90c7092fd3..a135adc4334cc 100644 --- a/docs/ml-tuning.md +++ b/docs/ml-tuning.md @@ -62,7 +62,7 @@ To help construct the parameter grid, users can use the [`ParamGridBuilder`](api After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset. -## Example: model selection via cross-validation +**Examples: model selection via cross-validation** The following example demonstrates using `CrossValidator` to select from a grid of parameters. @@ -75,15 +75,23 @@ However, it is also a well-established method for choosing parameters which is m
+ +Refer to the [`CrossValidator` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.CrossValidator) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala %}
+ +Refer to the [`CrossValidator` Java docs](api/java/org/apache/spark/ml/tuning/CrossValidator.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java %}
+Refer to the [`CrossValidator` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.CrossValidator) for more details on the API. + {% include_example python/ml/cross_validator.py %}
@@ -102,19 +110,28 @@ It splits the dataset into these two parts using the `trainRatio` parameter. For Like `CrossValidator`, `TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset. -## Example: model selection via train validation split +**Examples: model selection via train validation split**
+ +Refer to the [`TrainValidationSplit` Scala docs](api/scala/index.html#org.apache.spark.ml.tuning.TrainValidationSplit) for details on the API. + {% include_example scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala %}
+ +Refer to the [`TrainValidationSplit` Java docs](api/java/org/apache/spark/ml/tuning/TrainValidationSplit.html) for details on the API. + {% include_example java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java %}
+ +Refer to the [`TrainValidationSplit` Python docs](api/python/pyspark.ml.html#pyspark.ml.tuning.TrainValidationSplit) for more details on the API. + {% include_example python/ml/train_validation_split.py %}
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d5f6ae379a85e..8990e95796b67 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -24,13 +24,11 @@ variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). The implementation in `spark.mllib` has the following parameters: -* *k* is the number of desired clusters. +* *k* is the number of desired clusters. Note that it is possible for fewer than k clusters to be returned, for example, if there are fewer than k distinct points to cluster. * *maxIterations* is the maximum number of iterations to run. * *initializationMode* specifies either random initialization or initialization via k-means\|\|. -* *runs* is the number of times to run the k-means algorithm (k-means is not -guaranteed to find a globally optimal solution, and when run multiple times on -a given dataset, the algorithm returns the best clustering result). +* *runs* This param has no effect since Spark 2.0.0. * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. * *initialModel* is an optional set of cluster centers used for initialization. If this parameter is supplied, only one run is performed. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 87e1e027e945b..42568c312e70e 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -227,22 +227,19 @@ both speed and statistical learning behavior. [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) implements Chi-Squared feature selection. It operates on labeled data with categorical features. ChiSqSelector uses the [Chi-Squared test of independence](https://en.wikipedia.org/wiki/Chi-squared_test) to decide which -features to choose. It supports three selection methods: `KBest`, `Percentile` and `FPR`: +features to choose. It supports three selection methods: `numTopFeatures`, `percentile`, `fpr`: -* `KBest` chooses the `k` top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. -* `Percentile` is similar to `KBest` but chooses a fraction of all features instead of a fixed number. -* `FPR` chooses all features whose false positive rate meets some threshold. +* `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. This is akin to yielding the features with the most predictive power. +* `percentile` is similar to `numTopFeatures` but chooses a fraction of all features instead of a fixed number. +* `fpr` chooses all features whose p-value is below a threshold, thus controlling the false positive rate of selection. -By default, the selection method is `KBest`, the default number of top features is 50. User can use -`setNumTopFeatures`, `setPercentile` and `setAlpha` to set different selection methods. +By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. +The user can choose a selection method using `setSelectorType`. The number of features to select can be tuned using a held-out validation set. ### Model Fitting -`ChiSqSelector` takes a `numTopFeatures` parameter specifying the number of top features that -the selector will select. 
- The [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method takes an input of `RDD[LabeledPoint]` with categorical features, learns the summary statistics, and then returns a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md index d90905a86ade9..ca84551506b2b 100644 --- a/docs/mllib-isotonic-regression.md +++ b/docs/mllib-isotonic-regression.md @@ -27,7 +27,7 @@ best fitting the original data points. [pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111) which uses an approach to [parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10). -The training input is a RDD of tuples of three double values that represent +The training input is an RDD of tuples of three double values that represent label, feature and weight in this order. Additionally IsotonicRegression algorithm has one optional parameter called $isotonic$ defaulting to true. This argument specifies if the isotonic regression is diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 816bdf1317000..3085539b40e61 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -139,7 +139,7 @@ and logistic regression. Linear SVMs supports only binary classification, while logistic regression supports both binary and multiclass classification problems. For both methods, `spark.mllib` supports L1 and L2 regularized variants. -The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html) in MLlib, +The training data set is represented by an RDD of [LabeledPoint](mllib-data-types.html#labeled-point) in MLlib, where labels are class indices starting from zero: $0, 1, 2, \ldots$. ### Linear Support Vector Machines (SVMs) @@ -491,5 +491,3 @@ Algorithms are all implemented in Scala: * [RidgeRegressionWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD) * [LassoWithSGD](api/scala/index.html#org.apache.spark.mllib.regression.LassoWithSGD) -Python calls the Scala implementation via -[PythonMLLibAPI](api/scala/index.html#org.apache.spark.mllib.api.python.PythonMLLibAPI). diff --git a/docs/monitoring.md b/docs/monitoring.md index 5bc5e18c4d45f..2eef4568d00e9 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -41,7 +41,7 @@ directory must be supplied in the `spark.history.fs.logDirectory` configuration and should contain sub-directories that each represents an application's event logs. The spark jobs themselves must be configured to log events, and to log them to the same shared, -writeable directory. For example, if the server was configured with a log directory of +writable directory. For example, if the server was configured with a log directory of `hdfs://namenode/shared/spark-logs`, then the client-side options would be: ``` diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 7516579ec6dbf..58bf17b4a84ef 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -59,6 +59,8 @@ Spark {{site.SPARK_VERSION}} works with Java 7 and higher. If you are using Java for concisely writing functions, otherwise you can use the classes in the [org.apache.spark.api.java.function](api/java/index.html?org/apache/spark/api/java/function/package-summary.html) package. +Note that support for Java 7 is deprecated as of Spark 2.0.0 and may be removed in Spark 2.2.0. 
+ To write a Spark application in Java, you need to add a dependency on Spark. Spark is available through Maven Central at: groupId = org.apache.spark @@ -87,6 +89,8 @@ import org.apache.spark.SparkConf Spark {{site.SPARK_VERSION}} works with Python 2.6+ or Python 3.4+. It can use the standard CPython interpreter, so C libraries like NumPy can be used. It also works with PyPy 2.3+. +Note that support for Python 2.6 is deprecated as of Spark 2.0.0, and may be removed in Spark 2.2.0. + To run Spark applications in Python, use the `bin/spark-submit` script located in the Spark directory. This script will load Spark's Java/Scala libraries and allow you to submit applications to a cluster. You can also use `bin/pyspark` to launch an interactive Python shell. @@ -339,7 +343,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Scala API also supports several other data formats: @@ -371,7 +375,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Java API also supports several other data formats: @@ -403,7 +407,7 @@ Some notes on reading files with Spark: * All of Spark's file-based input methods, including `textFile`, support running on directories, compressed files, and wildcards as well. For example, you can use `textFile("/my/directory")`, `textFile("/my/directory/*.txt")`, and `textFile("/my/directory/*.gz")`. -* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. 
By default, Spark creates one partition for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. +* The `textFile` method also takes an optional second argument for controlling the number of partitions of the file. By default, Spark creates one partition for each block of the file (blocks being 128MB by default in HDFS), but you can also ask for a higher number of partitions by passing a larger value. Note that you cannot have fewer partitions than blocks. Apart from text files, Spark's Python API also supports several other data formats: diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 923d8dbebf3d2..8d5ad12cb85be 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -367,17 +367,6 @@ See the [configuration page](configuration.html) for information on Spark config
[host_path:]container_path[:ro|:rw]
- - spark.mesos.executor.docker.portmaps - (none) - - Set the list of incoming ports exposed by the Docker image, which was set using - spark.mesos.executor.docker.image. The format of this property is a comma-separated list of - mappings which take the form: - -
host_port:container_port[:tcp|:udp]
- - spark.mesos.executor.home driver side SPARK_HOME @@ -505,12 +494,26 @@ See the [configuration page](configuration.html) for information on Spark config Set the maximum number GPU resources to acquire for this job. Note that executors will still launch when no GPU resources are found since this configuration is just a upper limit and not a guaranteed amount. + + + spark.mesos.network.name + (none) + + Attach containers to the given named network. If this job is + launched in cluster mode, also launch the driver in the given named + network. See + the Mesos CNI docs + for more details. + spark.mesos.fetcherCache.enable false - If set to `true`, all URIs (example: `spark.executor.uri`, `spark.mesos.uris`) will be cached by the [Mesos fetcher cache](http://mesos.apache.org/documentation/latest/fetcher/) + If set to `true`, all URIs (example: `spark.executor.uri`, + `spark.mesos.uris`) will be cached by the Mesos + Fetcher Cache diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index cd18808681ece..4d1fafc07b8fc 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -117,28 +117,6 @@ To use a custom metrics.properties for the application master and executors, upd Use lower-case suffixes, e.g. k, m, g, t, and p, for kibi-, mebi-, gibi-, tebi-, and pebibytes, respectively. - - spark.driver.memory - 1g - - Amount of memory to use for the driver process, i.e. where SparkContext is initialized. - (e.g. 1g, 2g). - -
Note: In client mode, this config must not be set through the SparkConf - directly in your application, because the driver JVM has already started at that point. - Instead, please set this through the --driver-memory command line option - or in your default properties file. - - - - spark.driver.cores - 1 - - Number of cores used by the driver in YARN cluster mode. - Since the driver is run in the same JVM as the YARN Application Master in cluster mode, this also controls the cores used by the YARN Application Master. - In client mode, use spark.yarn.am.cores to control the number of cores used by the YARN Application Master instead. - - spark.yarn.am.cores 1 @@ -233,13 +211,6 @@ To use a custom metrics.properties for the application master and executors, upd Comma-separated list of jars to be placed in the working directory of each executor. - - spark.executor.cores - 1 in YARN mode, all the available cores on the worker in standalone mode. - - The number of cores to use on each executor. For YARN and standalone mode only. - - spark.executor.instances 2 @@ -247,13 +218,6 @@ To use a custom metrics.properties for the application master and executors, upd The number of executors for static allocation. With spark.dynamicAllocation.enabled, the initial set of executors will be at least this large. - - spark.executor.memory - 1g - - Amount of memory to use per executor process (e.g. 2g, 8g). - - spark.yarn.executor.memoryOverhead executorMemory * 0.10, with minimum of 384 @@ -559,6 +523,8 @@ pre-packaged distribution. 1. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to `org.apache.spark.network.yarn.YarnShuffleService`. +1. Increase `NodeManager's` heap size by setting `YARN_HEAPSIZE` (1000 by default) in `etc/hadoop/yarn-env.sh` +to avoid garbage collection issues during shuffle. 1. Restart all `NodeManager`s in your cluster. The following extra configuration options are available when the shuffle service is running on YARN: diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b9be7a7545ef8..656e7ecdab0bb 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -222,9 +222,9 @@ The `sql` function enables applications to run SQL queries programmatically and ## Global Temporary View -Temporay views in Spark SQL are session-scoped and will disappear if the session that creates it +Temporary views in Spark SQL are session-scoped and will disappear if the session that creates it terminates. If you want to have a temporary view that is shared among all sessions and keep alive -until the Spark application terminiates, you can create a global temporary view. Global temporary +until the Spark application terminates, you can create a global temporary view. Global temporary view is tied to a system preserved database `global_temp`, and we must use the qualified name to refer it, e.g. `SELECT * FROM global_temp.view1`. @@ -1029,7 +1029,7 @@ following command: bin/spark-shell --driver-class-path postgresql-9.4.1207.jar --jars postgresql-9.4.1207.jar {% endhighlight %} -Tables from the remote database can be loaded as a DataFrame or Spark SQL Temporary table using +Tables from the remote database can be loaded as a DataFrame or Spark SQL temporary view using the Data Sources API. Users can specify the JDBC connection properties in the data source options. 
user and password are normally provided as connection properties for logging into the data sources. In addition to the connection properties, Spark also supports @@ -1086,6 +1086,13 @@ the following case-sensitive options: + + maxConnections + + The maximum number of concurrent JDBC connections that can be used, if set. Only applies when writing. It works by limiting the operation's parallelism, which depends on the input's partition count. If its partition count exceeds this limit, the operation will coalesce the input to fewer partitions before writing. + + + isolationLevel diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index c1ef396907db7..b645d3c3a4b53 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -17,69 +17,72 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea
- import org.apache.kafka.clients.consumer.ConsumerRecord - import org.apache.kafka.common.serialization.StringDeserializer - import org.apache.spark.streaming.kafka010._ - import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent - import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe - - val kafkaParams = Map[String, Object]( - "bootstrap.servers" -> "localhost:9092,anotherhost:9092", - "key.deserializer" -> classOf[StringDeserializer], - "value.deserializer" -> classOf[StringDeserializer], - "group.id" -> "use_a_separate_group_id_for_each_stream", - "auto.offset.reset" -> "latest", - "enable.auto.commit" -> (false: java.lang.Boolean) - ) - - val topics = Array("topicA", "topicB") - val stream = KafkaUtils.createDirectStream[String, String]( - streamingContext, - PreferConsistent, - Subscribe[String, String](topics, kafkaParams) - ) - - stream.map(record => (record.key, record.value)) - +{% highlight scala %} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.spark.streaming.kafka010._ +import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent +import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe + +val kafkaParams = Map[String, Object]( + "bootstrap.servers" -> "localhost:9092,anotherhost:9092", + "key.deserializer" -> classOf[StringDeserializer], + "value.deserializer" -> classOf[StringDeserializer], + "group.id" -> "use_a_separate_group_id_for_each_stream", + "auto.offset.reset" -> "latest", + "enable.auto.commit" -> (false: java.lang.Boolean) +) + +val topics = Array("topicA", "topicB") +val stream = KafkaUtils.createDirectStream[String, String]( + streamingContext, + PreferConsistent, + Subscribe[String, String](topics, kafkaParams) +) + +stream.map(record => (record.key, record.value)) +{% endhighlight %} Each item in the stream is a [ConsumerRecord](http://kafka.apache.org/0100/javadoc/org/apache/kafka/clients/consumer/ConsumerRecord.html)
- import java.util.*; - import org.apache.spark.SparkConf; - import org.apache.spark.TaskContext; - import org.apache.spark.api.java.*; - import org.apache.spark.api.java.function.*; - import org.apache.spark.streaming.api.java.*; - import org.apache.spark.streaming.kafka010.*; - import org.apache.kafka.clients.consumer.ConsumerRecord; - import org.apache.kafka.common.TopicPartition; - import org.apache.kafka.common.serialization.StringDeserializer; - import scala.Tuple2; - - Map kafkaParams = new HashMap<>(); - kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); - kafkaParams.put("key.deserializer", StringDeserializer.class); - kafkaParams.put("value.deserializer", StringDeserializer.class); - kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); - kafkaParams.put("auto.offset.reset", "latest"); - kafkaParams.put("enable.auto.commit", false); - - Collection topics = Arrays.asList("topicA", "topicB"); - - final JavaInputDStream> stream = - KafkaUtils.createDirectStream( - streamingContext, - LocationStrategies.PreferConsistent(), - ConsumerStrategies.Subscribe(topics, kafkaParams) - ); - - stream.mapToPair( - new PairFunction, String, String>() { - @Override - public Tuple2 call(ConsumerRecord record) { - return new Tuple2<>(record.key(), record.value()); - } - }) +{% highlight java %} +import java.util.*; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.*; +import org.apache.spark.streaming.api.java.*; +import org.apache.spark.streaming.kafka010.*; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.StringDeserializer; +import scala.Tuple2; + +Map kafkaParams = new HashMap<>(); +kafkaParams.put("bootstrap.servers", "localhost:9092,anotherhost:9092"); +kafkaParams.put("key.deserializer", StringDeserializer.class); +kafkaParams.put("value.deserializer", StringDeserializer.class); +kafkaParams.put("group.id", "use_a_separate_group_id_for_each_stream"); +kafkaParams.put("auto.offset.reset", "latest"); +kafkaParams.put("enable.auto.commit", false); + +Collection topics = Arrays.asList("topicA", "topicB"); + +final JavaInputDStream> stream = + KafkaUtils.createDirectStream( + streamingContext, + LocationStrategies.PreferConsistent(), + ConsumerStrategies.Subscribe(topics, kafkaParams) + ); + +stream.mapToPair( + new PairFunction, String, String>() { + @Override + public Tuple2 call(ConsumerRecord record) { + return new Tuple2<>(record.key(), record.value()); + } + }) +{% endhighlight %}
@@ -109,32 +112,35 @@ If you have a use case that is better suited to batch processing, you can create
- // Import dependencies and create kafka params as in Create Direct Stream above - - val offsetRanges = Array( - // topic, partition, inclusive starting offset, exclusive ending offset - OffsetRange("test", 0, 0, 100), - OffsetRange("test", 1, 0, 100) - ) +{% highlight scala %} +// Import dependencies and create kafka params as in Create Direct Stream above - val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +val offsetRanges = Array( + // topic, partition, inclusive starting offset, exclusive ending offset + OffsetRange("test", 0, 0, 100), + OffsetRange("test", 1, 0, 100) +) +val rdd = KafkaUtils.createRDD[String, String](sparkContext, kafkaParams, offsetRanges, PreferConsistent) +{% endhighlight %}
- // Import dependencies and create kafka params as in Create Direct Stream above
-
- OffsetRange[] offsetRanges = {
-   // topic, partition, inclusive starting offset, exclusive ending offset
-   OffsetRange.create("test", 0, 0, 100),
-   OffsetRange.create("test", 1, 0, 100)
- };
-
- JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.<String, String>createRDD(
-   sparkContext,
-   kafkaParams,
-   offsetRanges,
-   LocationStrategies.PreferConsistent()
- );
+{% highlight java %}
+// Import dependencies and create kafka params as in Create Direct Stream above
+
+OffsetRange[] offsetRanges = {
+  // topic, partition, inclusive starting offset, exclusive ending offset
+  OffsetRange.create("test", 0, 0, 100),
+  OffsetRange.create("test", 1, 0, 100)
+};
+
+JavaRDD<ConsumerRecord<String, String>> rdd = KafkaUtils.<String, String>createRDD(
+  sparkContext,
+  kafkaParams,
+  offsetRanges,
+  LocationStrategies.PreferConsistent()
+);
+{% endhighlight %}
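As a quick editorial illustration (not part of the original guide) of what the batch `rdd` above supports, here is a hedged Scala sketch; it assumes the `rdd` value from the Scala snippet and only uses ordinary RDD operations, since each element is a `ConsumerRecord`. Remember that each `OffsetRange` is inclusive at its starting offset and exclusive at its ending offset.

{% highlight scala %}
// Assumes the `rdd` created by KafkaUtils.createRDD above.
// Each element is a ConsumerRecord, so ordinary RDD operations apply.
val countsPerTopic = rdd
  .map(record => (record.topic, 1L))
  .reduceByKey(_ + _)

countsPerTopic.collect().foreach { case (topic, n) =>
  println(s"$topic: $n records in the requested offset ranges")
}
{% endhighlight %}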
@@ -144,29 +150,33 @@ Note that you cannot use `PreferBrokers`, because without the stream there is no
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - rdd.foreachPartition { iter => - val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) - println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") - } - } +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + rdd.foreachPartition { iter => + val o: OffsetRange = offsetRanges(TaskContext.get.partitionId) + println(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}") + } +} +{% endhighlight %}
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-     rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
-       @Override
-       public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
-         OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
-         System.out.println(
-           o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
-       }
-     });
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    final OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+    rdd.foreachPartition(new VoidFunction<Iterator<ConsumerRecord<String, String>>>() {
+      @Override
+      public void call(Iterator<ConsumerRecord<String, String>> consumerRecords) {
+        OffsetRange o = offsetRanges[TaskContext.get().partitionId()];
+        System.out.println(
+          o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset());
+      }
+    });
+  }
+});
+{% endhighlight %}
@@ -183,25 +193,28 @@ Kafka has an offset commit API that stores offsets in a special Kafka topic. By
- stream.foreachRDD { rdd => - val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges - - // some time later, after outputs have completed - stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) - } - +{% highlight scala %} +stream.foreachRDD { rdd => + val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges + + // some time later, after outputs have completed + stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges) +} +{% endhighlight %} As with HasOffsetRanges, the cast to CanCommitOffsets will only succeed if called on the result of createDirectStream, not after transformations. The commitAsync call is threadsafe, but must occur after outputs if you want meaningful semantics.
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     // some time later, after outputs have completed
-     ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
-   }
- });
+{% highlight java %}
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    // some time later, after outputs have completed
+    ((CanCommitOffsets) stream.inputDStream()).commitAsync(offsetRanges);
+  }
+});
+{% endhighlight %}
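Because the guide stresses that `commitAsync` should run only after your outputs have completed, here is a hedged Scala sketch (editorial, not from the patch) that makes that ordering explicit; the `saveAsTextFile` call and output path are illustrative placeholders, and `stream` is the DStream created earlier.

{% highlight scala %}
stream.foreachRDD { rdd =>
  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

  // Write the batch out first (illustrative output; replace with your own sink)...
  rdd.map(record => record.value).saveAsTextFile(s"/tmp/kafka-batch-${System.currentTimeMillis}")

  // ...and only then commit, so offsets advance only for data that was actually written.
  stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
}
{% endhighlight %}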
@@ -210,64 +223,68 @@ For data stores that support transactions, saving offsets in the same transactio
- // The details depend on your data store, but the general idea looks like this
+{% highlight scala %}
+// The details depend on your data store, but the general idea looks like this

- // begin from the the offsets committed to the database
- val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet =>
-   new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset")
- }.toMap
+// begin from the offsets committed to the database
+val fromOffsets = selectOffsetsFromYourDatabase.map { resultSet =>
+  new TopicPartition(resultSet.string("topic"), resultSet.int("partition")) -> resultSet.long("offset")
+}.toMap

- val stream = KafkaUtils.createDirectStream[String, String](
-   streamingContext,
-   PreferConsistent,
-   Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets)
- )
+val stream = KafkaUtils.createDirectStream[String, String](
+  streamingContext,
+  PreferConsistent,
+  Assign[String, String](fromOffsets.keys.toList, kafkaParams, fromOffsets)
+)

- stream.foreachRDD { rdd =>
-   val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+stream.foreachRDD { rdd =>
+  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges

-   val results = yourCalculation(rdd)
+  val results = yourCalculation(rdd)

-   // begin your transaction
+  // begin your transaction

-   // update results
-   // update offsets where the end of existing offsets matches the beginning of this batch of offsets
-   // assert that offsets were updated correctly
+  // update results
+  // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+  // assert that offsets were updated correctly

-   // end your transaction
- }
+  // end your transaction
+}
+{% endhighlight %}
- // The details depend on your data store, but the general idea looks like this
-
- // begin from the the offsets committed to the database
- Map<TopicPartition, Long> fromOffsets = new HashMap<>();
- for (resultSet : selectOffsetsFromYourDatabase)
-   fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
- }
-
- JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
-   streamingContext,
-   LocationStrategies.PreferConsistent(),
-   ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
- );
-
- stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
-   @Override
-   public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
-     OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
-
-     Object results = yourCalculation(rdd);
-
-     // begin your transaction
-
-     // update results
-     // update offsets where the end of existing offsets matches the beginning of this batch of offsets
-     // assert that offsets were updated correctly
-
-     // end your transaction
-   }
- });
+{% highlight java %}
+// The details depend on your data store, but the general idea looks like this
+
+// begin from the offsets committed to the database
+Map<TopicPartition, Long> fromOffsets = new HashMap<>();
+for (resultSet : selectOffsetsFromYourDatabase) {
+  fromOffsets.put(new TopicPartition(resultSet.string("topic"), resultSet.int("partition")), resultSet.long("offset"));
+}
+
+JavaInputDStream<ConsumerRecord<String, String>> stream = KafkaUtils.createDirectStream(
+  streamingContext,
+  LocationStrategies.PreferConsistent(),
+  ConsumerStrategies.<String, String>Assign(fromOffsets.keySet(), kafkaParams, fromOffsets)
+);
+
+stream.foreachRDD(new VoidFunction<JavaRDD<ConsumerRecord<String, String>>>() {
+  @Override
+  public void call(JavaRDD<ConsumerRecord<String, String>> rdd) {
+    OffsetRange[] offsetRanges = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+
+    Object results = yourCalculation(rdd);
+
+    // begin your transaction
+
+    // update results
+    // update offsets where the end of existing offsets matches the beginning of this batch of offsets
+    // assert that offsets were updated correctly
+
+    // end your transaction
+  }
+});
+{% endhighlight %}
@@ -277,25 +294,29 @@ The new Kafka consumer [supports SSL](http://kafka.apache.org/documentation.html
- val kafkaParams = Map[String, Object]( - // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS - "security.protocol" -> "SSL", - "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", - "ssl.truststore.password" -> "test1234", - "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", - "ssl.keystore.password" -> "test1234", - "ssl.key.password" -> "test1234" - ) +{% highlight scala %} +val kafkaParams = Map[String, Object]( + // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS + "security.protocol" -> "SSL", + "ssl.truststore.location" -> "/some-directory/kafka.client.truststore.jks", + "ssl.truststore.password" -> "test1234", + "ssl.keystore.location" -> "/some-directory/kafka.client.keystore.jks", + "ssl.keystore.password" -> "test1234", + "ssl.key.password" -> "test1234" +) +{% endhighlight %}
- Map<String, Object> kafkaParams = new HashMap<String, Object>();
- // the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
- kafkaParams.put("security.protocol", "SSL");
- kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
- kafkaParams.put("ssl.truststore.password", "test1234");
- kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
- kafkaParams.put("ssl.keystore.password", "test1234");
- kafkaParams.put("ssl.key.password", "test1234");
+{% highlight java %}
+Map<String, Object> kafkaParams = new HashMap<String, Object>();
+// the usual params, make sure to change the port in bootstrap.servers if 9092 is not TLS
+kafkaParams.put("security.protocol", "SSL");
+kafkaParams.put("ssl.truststore.location", "/some-directory/kafka.client.truststore.jks");
+kafkaParams.put("ssl.truststore.password", "test1234");
+kafkaParams.put("ssl.keystore.location", "/some-directory/kafka.client.keystore.jks");
+kafkaParams.put("ssl.keystore.password", "test1234");
+kafkaParams.put("ssl.key.password", "test1234");
+{% endhighlight %}
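For completeness, a hedged Scala sketch (editorial, not from the patch) showing the TLS-enabled `kafkaParams` being used exactly like the plaintext configuration earlier; `streamingContext` and `topics` are the values from the Create Direct Stream section, and the map is assumed to also carry the usual deserializer and group id entries.

{% highlight scala %}
// kafkaParams here is the Scala map defined above, including the SSL entries.
val secureStream = KafkaUtils.createDirectStream[String, String](
  streamingContext,
  PreferConsistent,
  Subscribe[String, String](topics, kafkaParams)
)
{% endhighlight %}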
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 0b0315b366501..18fc1cd934826 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2191,7 +2191,7 @@ consistent batch processing times. Make sure you set the CMS GC on both the driv - When data is received from a stream source, receiver creates blocks of data. A new block of data is generated every blockInterval milliseconds. N blocks of data are created during the batchInterval where N = batchInterval/blockInterval. These blocks are distributed by the BlockManager of the current executor to the block managers of other executors. After that, the Network Input Tracker running on the driver is informed about the block locations for further processing. -- A RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. +- An RDD is created on the driver for the blocks created during the batchInterval. The blocks generated during the batchInterval are partitions of the RDD. Each partition is a task in spark. blockInterval== batchinterval would mean that a single partition is created and probably it is processed locally. - The map tasks on the blocks are processed in the executors (one that received the block, and another where the block was replicated) that has the blocks irrespective of block interval, unless non-local scheduling kicks in. Having bigger blockinterval means bigger blocks. A high value of `spark.locality.wait` increases the chance of processing a block on the local node. A balance needs to be found out between these two parameters to ensure that the bigger blocks are processed locally. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index a6c3b3a9024d8..2458bb5ffa298 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -19,97 +19,103 @@ application. See the [Deploying](#deploying) subsection below.
+{% highlight scala %} - // Subscribe to 1 topic - val ds1 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1") - .load() - ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to 1 topic +val ds1 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1") + .load() +ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to multiple topics - val ds2 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribe", "topic1,topic2") - .load() - ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to multiple topics +val ds2 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribe", "topic1,topic2") + .load() +ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] - // Subscribe to a pattern - val ds3 = spark - .readStream - .format("kafka") - .option("kafka.bootstrap.servers", "host1:port1,host2:port2") - .option("subscribePattern", "topic.*") - .load() - ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") - .as[(String, String)] +// Subscribe to a pattern +val ds3 = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", "host1:port1,host2:port2") + .option("subscribePattern", "topic.*") + .load() +ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .as[(String, String)] +{% endhighlight %}
+{% highlight java %}
- // Subscribe to 1 topic
- Dataset<Row> ds1 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1")
-   .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to 1 topic
+Dataset<Row> ds1 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1")
+  .load();
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

- // Subscribe to multiple topics
- Dataset<Row> ds2 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1,topic2")
-   .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to multiple topics
+Dataset<Row> ds2 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribe", "topic1,topic2")
+  .load();
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

- // Subscribe to a pattern
- Dataset<Row> ds3 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribePattern", "topic.*")
-   .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+// Subscribe to a pattern
+Dataset<Row> ds3 = spark
+  .readStream()
+  .format("kafka")
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
+  .option("subscribePattern", "topic.*")
+  .load();
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");
+{% endhighlight %}
+{% highlight python %}
- # Subscribe to 1 topic
- ds1 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1")
-   .load()
- ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to 1 topic
+ds1 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1") \
+  .load()
+ds1.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

- # Subscribe to multiple topics
- ds2 = spark
-   .readStream
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribe", "topic1,topic2")
-   .load()
- ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to multiple topics
+ds2 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribe", "topic1,topic2") \
+  .load()
+ds2.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

- # Subscribe to a pattern
- ds3 = spark
-   .readStream()
-   .format("kafka")
-   .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
-   .option("subscribePattern", "topic.*")
-   .load()
- ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+# Subscribe to a pattern
+ds3 = spark \
+  .readStream \
+  .format("kafka") \
+  .option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
+  .option("subscribePattern", "topic.*") \
+  .load()
+ds3.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+{% endhighlight %}
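To show where these sources lead, here is a hedged end-to-end Scala sketch (editorial, not from the patch) that takes `ds1` from the Scala example above and streams its keys and values to the console sink, mirroring the StructuredKafkaWordCount examples added elsewhere in this patch.

{% highlight scala %}
// Assumes the `ds1` Dataset created in the Scala example above.
val query = ds1
  .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
  .writeStream
  .outputMode("append")
  .format("console")
  .start()

query.awaitTermination()
{% endhighlight %}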
@@ -234,6 +240,7 @@ Kafka's own configurations can be set via `DataStreamReader.option` with `kafka. [Kafka consumer config docs](http://kafka.apache.org/documentation.html#newconsumerconfigs). Note that the following Kafka params cannot be set and the Kafka source will throw an exception: + - **group.id**: Kafka source will create a unique group id for each query automatically. - **auto.offset.reset**: Set the source option `startingOffsets` to specify where to start instead. Structured Streaming manages which offsets are consumed internally, rather diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index d838ed35a14fd..77b66b3b3a497 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -58,7 +58,7 @@ from pyspark.sql.functions import explode from pyspark.sql.functions import split spark = SparkSession \ - .builder() \ + .builder \ .appName("StructuredNetworkWordCount") \ .getOrCreate() {% endhighlight %} @@ -1087,9 +1087,185 @@ spark.streams().awaitAnyTermination() # block until any one of them terminates
-Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` -([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), -which will give you regular callback-based updates when queries are started and terminated. + +## Monitoring Streaming Queries +There are two ways you can monitor queries. You can directly get the current status +of an active query using `streamingQuery.status`, which will return a `StreamingQueryStatus` object +([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryStatus)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryStatus.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryStatus) docs) +that has all the details like current ingestion rates, processing rates, average latency, +details of the currently active trigger, etc. + +
+
+ +{% highlight scala %} +val query: StreamingQuery = ... + +println(query.status) + +/* Will print the current status of the query + +Status of query 'queryName' + Query id: 1 + Status timestamp: 123 + Input rate: 15.5 rows/sec + Processing rate 23.5 rows/sec + Latency: 345.0 ms + Trigger details: + batchId: 5 + isDataPresentInTrigger: true + isTriggerActive: true + latency.getBatch.total: 20 + latency.getOffset.total: 10 + numRows.input.total: 100 + Source statuses [1 source]: + Source 1 - MySource1 + Available offset: 0 + Input rate: 15.5 rows/sec + Processing rate: 23.5 rows/sec + Trigger details: + numRows.input.source: 100 + latency.getOffset.source: 10 + latency.getBatch.source: 20 + Sink status - MySink + Committed offsets: [1, -] +*/ +{% endhighlight %} + +
+
+
+{% highlight java %}
+StreamingQuery query = ...
+
+System.out.println(query.status());
+
+/* Will print the current status of the query
+
+Status of query 'queryName'
+ Query id: 1
+ Status timestamp: 123
+ Input rate: 15.5 rows/sec
+ Processing rate 23.5 rows/sec
+ Latency: 345.0 ms
+ Trigger details:
+ batchId: 5
+ isDataPresentInTrigger: true
+ isTriggerActive: true
+ latency.getBatch.total: 20
+ latency.getOffset.total: 10
+ numRows.input.total: 100
+ Source statuses [1 source]:
+ Source 1 - MySource1
+ Available offset: 0
+ Input rate: 15.5 rows/sec
+ Processing rate: 23.5 rows/sec
+ Trigger details:
+ numRows.input.source: 100
+ latency.getOffset.source: 10
+ latency.getBatch.source: 20
+ Sink status - MySink
+ Committed offsets: [1, -]
+*/
+{% endhighlight %}
+
+
+
+{% highlight python %}
+query = ...  # a StreamingQuery
+
+print(query.status)
+
+'''
+Will print the current status of the query
+
+Status of query 'queryName'
+ Query id: 1
+ Status timestamp: 123
+ Input rate: 15.5 rows/sec
+ Processing rate 23.5 rows/sec
+ Latency: 345.0 ms
+ Trigger details:
+ batchId: 5
+ isDataPresentInTrigger: true
+ isTriggerActive: true
+ latency.getBatch.total: 20
+ latency.getOffset.total: 10
+ numRows.input.total: 100
+ Source statuses [1 source]:
+ Source 1 - MySource1
+ Available offset: 0
+ Input rate: 15.5 rows/sec
+ Processing rate: 23.5 rows/sec
+ Trigger details:
+ numRows.input.source: 100
+ latency.getOffset.source: 10
+ latency.getBatch.source: 20
+ Sink status - MySink
+ Committed offsets: [1, -]
+'''
+{% endhighlight %}
+
+
+
+
+You can also asynchronously monitor all queries associated with a
+`SparkSession` by attaching a `StreamingQueryListener`
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs).
+Once you attach your custom `StreamingQueryListener` object with
+`sparkSession.streams.addListener()`, you will get callbacks when a query is started and
+terminated and when there is progress made in an active query. Here is an example:
+
+
+
+{% highlight scala %}
+val spark: SparkSession = ...
+
+spark.streams.addListener(new StreamingQueryListener() {
+
+  override def onQueryStarted(queryStarted: QueryStartedEvent): Unit = {
+    println("Query started: " + queryStarted.queryStatus.name)
+  }
+  override def onQueryTerminated(queryTerminated: QueryTerminatedEvent): Unit = {
+    println("Query terminated: " + queryTerminated.queryStatus.name)
+  }
+  override def onQueryProgress(queryProgress: QueryProgressEvent): Unit = {
+    println("Query made progress: " + queryProgress.queryStatus)
+  }
+})
+{% endhighlight %}
+
+
+
+{% highlight java %}
+SparkSession spark = ...
+
+spark.streams().addListener(new StreamingQueryListener() {
+
+  @Override
+  public void onQueryStarted(QueryStartedEvent queryStarted) {
+    System.out.println("Query started: " + queryStarted.queryStatus().name());
+  }
+  @Override
+  public void onQueryTerminated(QueryTerminatedEvent queryTerminated) {
+    System.out.println("Query terminated: " + queryTerminated.queryStatus().name());
+  }
+  @Override
+  public void onQueryProgress(QueryProgressEvent queryProgress) {
+    System.out.println("Query made progress: " + queryProgress.queryStatus());
+  }
+});
+{% endhighlight %}
+
+
+{% highlight bash %} +Not available in Python. +{% endhighlight %} + +
+
## Recovering from Failures with Checkpointing In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). diff --git a/docs/tuning.md b/docs/tuning.md index 9c43b315bbb9e..0de303a3bd9bf 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -224,8 +224,8 @@ temporary objects created during task execution. Some steps which may be useful * As an example, if your task is reading data from HDFS, the amount of memory used by the task can be estimated using the size of the data block read from HDFS. Note that the size of a decompressed block is often 2 or 3 times the - size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 64 MB, - we can estimate size of Eden to be `4*3*64MB`. + size of the block. So if we wish to have 3 or 4 tasks' worth of working space, and the HDFS block size is 128 MB, + we can estimate size of Eden to be `4*3*128MB`. * Monitor how the frequency and time taken by garbage collection changes with the new settings. diff --git a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java index 7df145e3117b8..89855e81f1f7a 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaSparkPi.java @@ -54,7 +54,7 @@ public static void main(String[] args) throws Exception { public Integer call(Integer integer) { double x = Math.random() * 2 - 1; double y = Math.random() * 2 - 1; - return (x * x + y * y < 1) ? 1 : 0; + return (x * x + y * y <= 1) ? 1 : 0; } }).reduce(new Function2() { @Override diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java new file mode 100644 index 0000000000000..3684a87e22e7b --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaInteractionExample.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.ml.feature.Interaction; +import org.apache.spark.ml.feature.VectorAssembler; +import org.apache.spark.sql.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.List; + +// $example on$ +// $example off$ + +public class JavaInteractionExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaInteractionExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1, 1, 2, 3, 8, 4, 5), + RowFactory.create(2, 4, 3, 8, 7, 9, 8), + RowFactory.create(3, 6, 1, 9, 2, 3, 6), + RowFactory.create(4, 10, 8, 6, 9, 4, 5), + RowFactory.create(5, 9, 2, 7, 10, 7, 3), + RowFactory.create(6, 1, 1, 4, 2, 8, 4) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("id1", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id2", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id3", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id4", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id5", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id6", DataTypes.IntegerType, false, Metadata.empty()), + new StructField("id7", DataTypes.IntegerType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + VectorAssembler assembler1 = new VectorAssembler() + .setInputCols(new String[]{"id2", "id3", "id4"}) + .setOutputCol("vec1"); + + Dataset assembled1 = assembler1.transform(df); + + VectorAssembler assembler2 = new VectorAssembler() + .setInputCols(new String[]{"id5", "id6", "id7"}) + .setOutputCol("vec2"); + + Dataset assembled2 = assembler2.transform(assembled1).select("id1", "vec1", "vec2"); + + Interaction interaction = new Interaction() + .setInputCols(new String[]{"id1","vec1","vec2"}) + .setOutputCol("interactedCol"); + + Dataset interacted = interaction.transform(assembled2); + + interacted.show(false); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java index b8fb5972ea418..4cdec21d23023 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java @@ -60,8 +60,8 @@ public static void main(String[] args) { LogisticRegressionModel mlrModel = mlr.fit(training); // Print the coefficients and intercepts for logistic regression with multinomial family - System.out.println("Multinomial coefficients: " - + lrModel.coefficientMatrix() + "\nMultinomial intercepts: " + mlrModel.interceptVector()); + System.out.println("Multinomial coefficients: " + lrModel.coefficientMatrix() + + "\nMultinomial intercepts: " + mlrModel.interceptVector()); // $example off$ spark.stop(); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java new file mode 100644 index 0000000000000..0f45cfeca4429 --- /dev/null +++ 
b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredKafkaWordCount.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.sql.streaming; + +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.streaming.StreamingQuery; + +import java.util.Arrays; +import java.util.Iterator; + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: JavaStructuredKafkaWordCount + * The Kafka "bootstrap.servers" configuration. A + * comma-separated list of host:port. + * There are three kinds of type, i.e. 'assign', 'subscribe', + * 'subscribePattern'. + * |- Specific TopicPartitions to consume. Json string + * | {"topicA":[0,1],"topicB":[2,4]}. + * |- The topic list to subscribe. A comma-separated list of + * | topics. + * |- The pattern used to subscribe to topic(s). + * | Java regex string. + * |- Only one of "assign, "subscribe" or "subscribePattern" options can be + * | specified for Kafka source. + * Different value format depends on the value of 'subscribe-type'. 
+ * + * Example: + * `$ bin/run-example \ + * sql.streaming.JavaStructuredKafkaWordCount host1:port1,host2:port2 \ + * subscribe topic1,topic2` + */ +public final class JavaStructuredKafkaWordCount { + + public static void main(String[] args) throws Exception { + if (args.length < 3) { + System.err.println("Usage: JavaStructuredKafkaWordCount " + + " "); + System.exit(1); + } + + String bootstrapServers = args[0]; + String subscribeType = args[1]; + String topics = args[2]; + + SparkSession spark = SparkSession + .builder() + .appName("JavaStructuredKafkaWordCount") + .getOrCreate(); + + // Create DataSet representing the stream of input lines from kafka + Dataset lines = spark + .readStream() + .format("kafka") + .option("kafka.bootstrap.servers", bootstrapServers) + .option(subscribeType, topics) + .load() + .selectExpr("CAST(value AS STRING)") + .as(Encoders.STRING()); + + // Generate running word count + Dataset wordCounts = lines.flatMap(new FlatMapFunction() { + @Override + public Iterator call(String x) { + return Arrays.asList(x.split(" ")).iterator(); + } + }, Encoders.STRING()).groupBy("value").count(); + + // Start running the query that prints the running counts to the console + StreamingQuery query = wordCounts.writeStream() + .outputMode("complete") + .format("console") + .start(); + + query.awaitTermination(); + } +} diff --git a/examples/src/main/python/mllib/k_means_example.py b/examples/src/main/python/mllib/k_means_example.py index 5c397e62ef10e..d6058f45020c4 100644 --- a/examples/src/main/python/mllib/k_means_example.py +++ b/examples/src/main/python/mllib/k_means_example.py @@ -36,8 +36,7 @@ parsedData = data.map(lambda line: array([float(x) for x in line.split(' ')])) # Build the model (cluster the data) - clusters = KMeans.train(parsedData, 2, maxIterations=10, - runs=10, initializationMode="random") + clusters = KMeans.train(parsedData, 2, maxIterations=10, initializationMode="random") # Evaluate clustering by computing Within Set Sum of Squared Errors def error(point): diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index e3f0c4aeef1b7..37029b76798f6 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -38,7 +38,7 @@ def f(_): x = random() * 2 - 1 y = random() * 2 - 1 - return 1 if x ** 2 + y ** 2 < 1 else 0 + return 1 if x ** 2 + y ** 2 <= 1 else 0 count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add) print("Pi is roughly %f" % (4.0 * count / n)) diff --git a/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py new file mode 100644 index 0000000000000..9e8a552b3b10b --- /dev/null +++ b/examples/src/main/python/sql/streaming/structured_kafka_wordcount.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Consumes messages from one or more topics in Kafka and does wordcount. + Usage: structured_kafka_wordcount.py + The Kafka "bootstrap.servers" configuration. A + comma-separated list of host:port. + There are three kinds of type, i.e. 'assign', 'subscribe', + 'subscribePattern'. + |- Specific TopicPartitions to consume. Json string + | {"topicA":[0,1],"topicB":[2,4]}. + |- The topic list to subscribe. A comma-separated list of + | topics. + |- The pattern used to subscribe to topic(s). + | Java regex string. + |- Only one of "assign, "subscribe" or "subscribePattern" options can be + | specified for Kafka source. + Different value format depends on the value of 'subscribe-type'. + + Run the example + `$ bin/spark-submit examples/src/main/python/sql/streaming/structured_kafka_wordcount.py \ + host1:port1,host2:port2 subscribe topic1,topic2` +""" +from __future__ import print_function + +import sys + +from pyspark.sql import SparkSession +from pyspark.sql.functions import explode +from pyspark.sql.functions import split + +if __name__ == "__main__": + if len(sys.argv) != 4: + print(""" + Usage: structured_kafka_wordcount.py + """, file=sys.stderr) + exit(-1) + + bootstrapServers = sys.argv[1] + subscribeType = sys.argv[2] + topics = sys.argv[3] + + spark = SparkSession\ + .builder\ + .appName("StructuredKafkaWordCount")\ + .getOrCreate() + + # Create DataSet representing the stream of input lines from kafka + lines = spark\ + .readStream\ + .format("kafka")\ + .option("kafka.bootstrap.servers", bootstrapServers)\ + .option(subscribeType, topics)\ + .load()\ + .selectExpr("CAST(value AS STRING)") + + # Split the lines into words + words = lines.select( + # explode turns each item in an array into a separate row + explode( + split(lines.value, ' ') + ).alias('word') + ) + + # Generate running word count + wordCounts = words.groupBy('word').count() + + # Start running the query that prints the running counts to the console + query = wordCounts\ + .writeStream\ + .outputMode('complete')\ + .format('console')\ + .start() + + query.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 3d02ce05619ad..a897cad02ffd7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -51,7 +51,8 @@ object LocalFileLR { showWarning() - val lines = scala.io.Source.fromFile(args(0)).getLines().toArray + val fileSrc = scala.io.Source.fromFile(args(0)) + val lines = fileSrc.getLines().toArray val points = lines.map(parsePoint _) val ITERATIONS = args(1).toInt @@ -69,6 +70,7 @@ object LocalFileLR { w -= gradient } + fileSrc.close() println("Final w: " + w) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala index 720d92fb9d029..121b768e4198e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalPi.scala @@ -26,7 +26,7 @@ object LocalPi { for (i <- 1 to 100000) { val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) count += 1 + if (x*x + y*y <= 1) count += 1 } println("Pi is roughly " + 4 * count / 100000.0) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala 
b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index 272c1a4fc2f47..a5cacf17a5cca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -34,7 +34,7 @@ object SparkPi { val count = spark.sparkContext.parallelize(1 until n, slices).map { i => val x = random * 2 - 1 val y = random * 2 - 1 - if (x*x + y*y < 1) 1 else 0 + if (x*x + y*y <= 1) 1 else 0 }.reduce(_ + _) println("Pi is roughly " + 4.0 * count / (n - 1)) spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala new file mode 100644 index 0000000000000..8113c992b1d69 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/InteractionExample.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.Interaction +import org.apache.spark.ml.feature.VectorAssembler +// $example off$ +import org.apache.spark.sql.SparkSession + +object InteractionExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName("InteractionExample") + .getOrCreate() + + // $example on$ + val df = spark.createDataFrame(Seq( + (1, 1, 2, 3, 8, 4, 5), + (2, 4, 3, 8, 7, 9, 8), + (3, 6, 1, 9, 2, 3, 6), + (4, 10, 8, 6, 9, 4, 5), + (5, 9, 2, 7, 10, 7, 3), + (6, 1, 1, 4, 2, 8, 4) + )).toDF("id1", "id2", "id3", "id4", "id5", "id6", "id7") + + val assembler1 = new VectorAssembler(). + setInputCols(Array("id2", "id3", "id4")). + setOutputCol("vec1") + + val assembled1 = assembler1.transform(df) + + val assembler2 = new VectorAssembler(). + setInputCols(Array("id5", "id6", "id7")). + setOutputCol("vec2") + + val assembled2 = assembler2.transform(assembled1).select("id1", "vec1", "vec2") + + val interaction = new Interaction() + .setInputCols(Array("id1", "vec1", "vec2")) + .setOutputCol("interactedCol") + + val interacted = interaction.transform(assembled2) + + interacted.show(truncate = false) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala new file mode 100644 index 0000000000000..c26f73e788814 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredKafkaWordCount.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.sql.streaming + +import org.apache.spark.sql.SparkSession + +/** + * Consumes messages from one or more topics in Kafka and does wordcount. + * Usage: StructuredKafkaWordCount + * The Kafka "bootstrap.servers" configuration. A + * comma-separated list of host:port. + * There are three kinds of type, i.e. 'assign', 'subscribe', + * 'subscribePattern'. + * |- Specific TopicPartitions to consume. Json string + * | {"topicA":[0,1],"topicB":[2,4]}. + * |- The topic list to subscribe. A comma-separated list of + * | topics. + * |- The pattern used to subscribe to topic(s). + * | Java regex string. + * |- Only one of "assign, "subscribe" or "subscribePattern" options can be + * | specified for Kafka source. + * Different value format depends on the value of 'subscribe-type'. + * + * Example: + * `$ bin/run-example \ + * sql.streaming.StructuredKafkaWordCount host1:port1,host2:port2 \ + * subscribe topic1,topic2` + */ +object StructuredKafkaWordCount { + def main(args: Array[String]): Unit = { + if (args.length < 3) { + System.err.println("Usage: StructuredKafkaWordCount " + + " ") + System.exit(1) + } + + val Array(bootstrapServers, subscribeType, topics) = args + + val spark = SparkSession + .builder + .appName("StructuredKafkaWordCount") + .getOrCreate() + + import spark.implicits._ + + // Create DataSet representing the stream of input lines from kafka + val lines = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", bootstrapServers) + .option(subscribeType, topics) + .load() + .selectExpr("CAST(value AS STRING)") + .as[String] + + // Generate running word count + val wordCounts = lines.flatMap(_.split(" ")).groupBy("value").count() + + // Start running the query that prints the running counts to the console + val query = wordCounts.writeStream + .outputMode("complete") + .format("console") + .start() + + query.awaitTermination() + } + +} +// scalastyle:on println diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala index 40d568a12c25d..13d717092a898 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/JsonUtils.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.kafka010 -import java.io.Writer - import scala.collection.mutable.HashMap import scala.util.control.NonFatal diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala index 61cba737d148a..341081a338c0e 100644 --- 
a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.io._ +import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -88,7 +90,10 @@ private[kafka010] case class KafkaSource( private val sc = sqlContext.sparkContext - private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong + private val pollTimeoutMs = sourceOptions.getOrElse( + "kafkaConsumer.pollTimeoutMs", + sc.conf.getTimeAsMs("spark.network.timeout", "120s").toString + ).toLong private val maxOffsetFetchAttempts = sourceOptions.getOrElse("fetchOffset.numRetries", "3").toInt @@ -111,7 +116,22 @@ private[kafka010] case class KafkaSource( * `KafkaConsumer.poll` may hang forever (KAFKA-1894). */ private lazy val initialPartitionOffsets = { - val metadataLog = new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) + val metadataLog = + new HDFSMetadataLog[KafkaSourceOffset](sqlContext.sparkSession, metadataPath) { + override def serialize(metadata: KafkaSourceOffset, out: OutputStream): Unit = { + val bytes = metadata.json.getBytes(StandardCharsets.UTF_8) + out.write(bytes.length) + out.write(bytes) + } + + override def deserialize(in: InputStream): KafkaSourceOffset = { + val length = in.read() + val bytes = new Array[Byte](length) + in.read(bytes) + KafkaSourceOffset(SerializedOffset(new String(bytes, StandardCharsets.UTF_8))) + } + } + metadataLog.get(0).getOrElse { val offsets = startingOffsets match { case EarliestOffsets => KafkaSourceOffset(fetchEarliestOffsets()) @@ -259,7 +279,7 @@ private[kafka010] case class KafkaSource( } }.toArray - // Create a RDD that reads from Kafka and get the (key, value) pair as byte arrays. + // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. val rdd = new KafkaSourceRDD( sc, executorKafkaParams, offsetRanges, pollTimeoutMs).map { cr => Row(cr.key, cr.value, cr.topic, cr.partition, cr.offset, cr.timestamp, cr.timestampType.id) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index b5ade982515f0..b5da415b3097e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.execution.streaming.Offset +import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} /** * An [[Offset]] for the [[KafkaSource]]. 
This one tracks all partitions of subscribed topics and @@ -27,9 +27,8 @@ import org.apache.spark.sql.execution.streaming.Offset */ private[kafka010] case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { - override def toString(): String = { - partitionToOffsets.toSeq.sortBy(_._1.toString).mkString("[", ", ", "]") - } + + override val json = JsonUtils.partitionOffsets(partitionToOffsets) } /** Companion object of the [[KafkaSourceOffset]] */ @@ -38,6 +37,7 @@ private[kafka010] object KafkaSourceOffset { def getPartitionOffsets(offset: Offset): Map[TopicPartition, Long] = { offset match { case o: KafkaSourceOffset => o.partitionToOffsets + case so: SerializedOffset => KafkaSourceOffset(so).partitionToOffsets case _ => throw new IllegalArgumentException( s"Invalid conversion from offset of ${offset.getClass} to KafkaSourceOffset") @@ -51,4 +51,10 @@ private[kafka010] object KafkaSourceOffset { def apply(offsetTuples: (String, Int, Long)*): KafkaSourceOffset = { KafkaSourceOffset(offsetTuples.map { case(t, p, o) => (new TopicPartition(t, p), o) }.toMap) } + + /** + * Returns [[KafkaSourceOffset]] from a JSON [[SerializedOffset]] + */ + def apply(offset: SerializedOffset): KafkaSourceOffset = + KafkaSourceOffset(JsonUtils.partitionOffsets(offset.json)) } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala index 7056a41b1751e..881018fd95665 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceOffsetSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.kafka010 +import java.io.File + +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.streaming.OffsetSuite +import org.apache.spark.sql.test.SharedSQLContext -class KafkaSourceOffsetSuite extends OffsetSuite { +class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext { compare( one = KafkaSourceOffset(("t", 0, 1L)), @@ -36,4 +40,53 @@ class KafkaSourceOffsetSuite extends OffsetSuite { compare( one = KafkaSourceOffset(("t", 0, 1L)), two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L))) + + + val kso1 = KafkaSourceOffset(("t", 0, 1L)) + val kso2 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L)) + val kso3 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L)) + + compare(KafkaSourceOffset(SerializedOffset(kso1.json)), + KafkaSourceOffset(SerializedOffset(kso2.json))) + + test("basic serialization - deserialization") { + assert(KafkaSourceOffset.getPartitionOffsets(kso1) == + KafkaSourceOffset.getPartitionOffsets(SerializedOffset(kso1.json))) + } + + + testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") { + withTempDir { temp => + // use non-existent directory to test whether log make the dir + val dir = new File(temp, "dir") + val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath) + val batch0 = OffsetSeq.fill(kso1) + val batch1 = OffsetSeq.fill(kso2, kso3) + + val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o => + SerializedOffset(o.json))): _*) + + val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o => + SerializedOffset(o.json))): _*) + + assert(metadataLog.add(0, batch0)) + assert(metadataLog.getLatest() === Some(0 -> batch0Serialized)) + assert(metadataLog.get(0) === Some(batch0Serialized)) 
+ + assert(metadataLog.add(1, batch1)) + assert(metadataLog.get(0) === Some(batch0Serialized)) + assert(metadataLog.get(1) === Some(batch1Serialized)) + assert(metadataLog.getLatest() === Some(1 -> batch1Serialized)) + assert(metadataLog.get(None, Some(1)) === + Array(0 -> batch0Serialized, 1 -> batch1Serialized)) + + // Adding the same batch does nothing + metadataLog.add(1, OffsetSeq.fill(LongOffset(3))) + assert(metadataLog.get(0) === Some(batch0Serialized)) + assert(metadataLog.get(1) === Some(batch1Serialized)) + assert(metadataLog.getLatest() === Some(1 -> batch1Serialized)) + assert(metadataLog.get(None, Some(1)) === + Array(0 -> batch0Serialized, 1 -> batch1Serialized)) + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala index ed4cc75920e8e..89e713f92df46 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala @@ -306,6 +306,30 @@ class KafkaSourceSuite extends KafkaSourceTest { ) } + test("starting offset is latest by default") { + val topic = newTopic() + testUtils.createTopic(topic, partitions = 5) + testUtils.sendMessages(topic, Array("0")) + require(testUtils.getLatestOffsets(Set(topic)).size === 5) + + val reader = spark + .readStream + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("subscribe", topic) + + val kafka = reader.load() + .selectExpr("CAST(value AS STRING)") + .as[String] + val mapped = kafka.map(_.toInt) + + testStream(mapped)( + makeSureGetOffsetCalled, + AddKafkaData(Set(topic), 1, 2, 3), + CheckAnswer(1, 2, 3) // should not have 0 + ) + } + test("bad source options") { def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = { val ex = intercept[IllegalArgumentException] { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 7e57bb18cbd50..794f53c5abfd0 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -57,7 +57,8 @@ import org.apache.spark.streaming.scheduler.rate.RateEstimator private[spark] class DirectKafkaInputDStream[K, V]( _ssc: StreamingContext, locationStrategy: LocationStrategy, - consumerStrategy: ConsumerStrategy[K, V] + consumerStrategy: ConsumerStrategy[K, V], + ppc: PerPartitionConfig ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets { val executorKafkaParams = { @@ -128,12 +129,9 @@ private[spark] class DirectKafkaInputDStream[K, V]( } } - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( - "spark.streaming.kafka.maxRatePerPartition", 0) - protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { @@ -144,11 +142,12 @@ private[spark] class 
DirectKafkaInputDStream[K, V]( val totalLag = lagPerPartition.values.sum lagPerPartition.map { case (tp, lag) => + val maxRateLimitPerPartition = ppc.maxRatePerPartition(tp) val backpressureRate = Math.round(lag / totalLag.toFloat * rate) tp -> (if (maxRateLimitPerPartition > 0) { Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate) } - case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition } + case None => offsets.map { case (tp, offset) => tp -> ppc.maxRatePerPartition(tp) } } if (effectiveRateLimitPerPartition.values.sum > 0) { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 5b5a9ac48c7ca..98394251bb233 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -66,7 +66,8 @@ private[spark] class KafkaRDD[K, V]( " must be set to false for executor kafka params, else offsets may commit before processing") // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time? - private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", 512) + private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", + conf.getTimeAsMs("spark.network.timeout", "120s")) private val cacheInitialCapacity = conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16) private val cacheMaxCapacity = diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala index b2190bfa05a3a..c11917f59d5b8 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala @@ -123,7 +123,31 @@ object KafkaUtils extends Logging { locationStrategy: LocationStrategy, consumerStrategy: ConsumerStrategy[K, V] ): InputDStream[ConsumerRecord[K, V]] = { - new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy) + val ppc = new DefaultPerPartitionConfig(ssc.sparkContext.getConf) + createDirectStream[K, V](ssc, locationStrategy, consumerStrategy, ppc) + } + + /** + * :: Experimental :: + * Scala constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details. + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. 
+ * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + ssc: StreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): InputDStream[ConsumerRecord[K, V]] = { + new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy, perPartitionConfig) } /** @@ -150,6 +174,33 @@ object KafkaUtils extends Logging { jssc.ssc, locationStrategy, consumerStrategy)) } + /** + * :: Experimental :: + * Java constructor for a DStream where + * each given Kafka topic/partition corresponds to an RDD partition. + * @param keyClass Class of the keys in the Kafka records + * @param valueClass Class of the values in the Kafka records + * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent, + * see [[LocationStrategies]] for more details. + * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe, + * see [[ConsumerStrategies]] for more details + * @param perPartitionConfig configuration of settings such as max rate on a per-partition basis. + * see [[PerPartitionConfig]] for more details. + * @tparam K type of Kafka message key + * @tparam V type of Kafka message value + */ + @Experimental + def createDirectStream[K, V]( + jssc: JavaStreamingContext, + locationStrategy: LocationStrategy, + consumerStrategy: ConsumerStrategy[K, V], + perPartitionConfig: PerPartitionConfig + ): JavaInputDStream[ConsumerRecord[K, V]] = { + new JavaInputDStream( + createDirectStream[K, V]( + jssc.ssc, locationStrategy, consumerStrategy, perPartitionConfig)) + } + /** * Tweak kafka params to prevent issues on executors */ diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala new file mode 100644 index 0000000000000..4792f2a955110 --- /dev/null +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/PerPartitionConfig.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka010 + +import org.apache.kafka.common.TopicPartition + +import org.apache.spark.SparkConf +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Interface for user-supplied configurations that can't otherwise be set via Spark properties, + * because they need tweaking on a per-partition basis, + */ +@Experimental +abstract class PerPartitionConfig extends Serializable { + /** + * Maximum rate (number of records per second) at which data will be read + * from each Kafka partition. 
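A minimal usage sketch of how the new `PerPartitionConfig` hook and the four-argument `createDirectStream` overload introduced in this patch are meant to be used together; the topic name, rates, broker address, and group id below are invented placeholders, not part of the patch:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka010._

val ssc = new StreamingContext(
  new SparkConf().setMaster("local[2]").setAppName("per-partition-rate-example"), Seconds(5))

// Hypothetical consumer configuration; broker address and group id are placeholders.
val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "localhost:9092",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "per-partition-rate-example")

// Throttle partition 0 of the hypothetical "events" topic harder than the other partitions.
val ppc = new PerPartitionConfig {
  override def maxRatePerPartition(tp: TopicPartition): Long =
    if (tp.topic == "events" && tp.partition == 0) 100L else 1000L
}

val stream = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](List("events"), kafkaParams),
  ppc)

stream.map(record => record.value).print()
ssc.start()
ssc.awaitTermination()
```

As the `DirectKafkaInputDStream` hunk above shows, `maxRatePerPartition` is consulted both when the backpressure rate estimator is active and when it is not, so a per-partition limit applies in either case.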
+ */ + def maxRatePerPartition(topicPartition: TopicPartition): Long +} + +/** + * Default per-partition configuration + */ +private class DefaultPerPartitionConfig(conf: SparkConf) + extends PerPartitionConfig { + val maxRate = conf.getLong("spark.streaming.kafka.maxRatePerPartition", 0) + + def maxRatePerPartition(topicPartition: TopicPartition): Long = maxRate +} diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala index 02aec43c3b34f..fde3714d3d02e 100644 --- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala @@ -252,7 +252,8 @@ class DirectKafkaStreamSuite val s = new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -272,6 +273,7 @@ class DirectKafkaStreamSuite collectedData.contains("b") } assert(!collectedData.contains("a")) + ssc.stop() } @@ -306,7 +308,8 @@ class DirectKafkaStreamSuite ConsumerStrategies.Assign[String, String]( List(topicPartition), kafkaParams.asScala, - Map(topicPartition -> 11L))) + Map(topicPartition -> 11L)), + new DefaultPerPartitionConfig(sparkConf)) s.consumer.poll(0) assert( s.consumer.position(topicPartition) >= offsetBeforeStart, @@ -324,6 +327,7 @@ class DirectKafkaStreamSuite collectedData.contains("b") } assert(!collectedData.contains("a")) + ssc.stop() } // Test to verify the offset ranges can be recovered from the checkpoints @@ -518,7 +522,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with backpressure disabled") { val topic = "maxMessagesPerPartition" - val kafkaStream = getDirectKafkaStream(topic, None) + val kafkaStream = getDirectKafkaStream(topic, None, None) val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L) assert(kafkaStream.maxMessagesPerPartition(input).get == @@ -528,7 +532,7 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition with no lag") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val kafkaStream = getDirectKafkaStream(topic, rateController, None) val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L) assert(kafkaStream.maxMessagesPerPartition(input).isEmpty) @@ -537,11 +541,19 @@ class DirectKafkaStreamSuite test("maxMessagesPerPartition respects max rate") { val topic = "maxMessagesPerPartition" val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000)) - val kafkaStream = getDirectKafkaStream(topic, rateController) + val ppc = Some(new PerPartitionConfig { + def maxRatePerPartition(tp: TopicPartition) = + if (tp.topic == topic && tp.partition == 0) { + 50 + } else { + 100 + } + }) + val kafkaStream = getDirectKafkaStream(topic, rateController, ppc) val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L) assert(kafkaStream.maxMessagesPerPartition(input).get == - Map(new TopicPartition(topic, 0) -> 
10L, new TopicPartition(topic, 1) -> 10L)) + Map(new TopicPartition(topic, 0) -> 5L, new TopicPartition(topic, 1) -> 10L)) } test("using rate controller") { @@ -570,7 +582,9 @@ class DirectKafkaStreamSuite new DirectKafkaInputDStream[String, String]( ssc, preferredHosts, - ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) { + ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala), + new DefaultPerPartitionConfig(sparkConf) + ) { override protected[streaming] val rateController = Some(new DirectKafkaRateController(id, estimator)) }.map(r => (r.key, r.value)) @@ -616,7 +630,10 @@ class DirectKafkaStreamSuite }.toSeq.sortBy { _._1 } } - private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = { + private def getDirectKafkaStream( + topic: String, + mockRateController: Option[RateController], + ppc: Option[PerPartitionConfig]) = { val batchIntervalMilliseconds = 100 val sparkConf = new SparkConf() @@ -643,7 +660,8 @@ class DirectKafkaStreamSuite tps.foreach(tp => consumer.seek(tp, 0)) consumer } - } + }, + ppc.getOrElse(new DefaultPerPartitionConfig(sparkConf)) ) { override protected[streaming] val rateController = mockRateController } diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala index c3c799375bbeb..d52c230eb7849 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala @@ -88,12 +88,12 @@ class DirectKafkaInputDStream[ protected val kc = new KafkaCluster(kafkaParams) - private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt( + private val maxRateLimitPerPartition: Long = context.sparkContext.getConf.getLong( "spark.streaming.kafka.maxRatePerPartition", 0) protected[streaming] def maxMessagesPerPartition( offsets: Map[TopicAndPartition, Long]): Option[Map[TopicAndPartition, Long]] = { - val estimatedRateLimit = rateController.map(_.getLatestRate().toInt) + val estimatedRateLimit = rateController.map(_.getLatestRate()) // calculate a per-partition rate limit based on current lag val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match { diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index b17e198077949..56f0cb0b166a2 100644 --- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -223,7 +223,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param sc SparkContext object * @param kafkaParams Kafka @@ -255,7 +255,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. 
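For context on the `createRDD` variants whose docs are touched here, a short, hypothetical batch-read sketch against the 0.8 integration; the broker address, topic name, and offsets are placeholders:

```scala
import kafka.serializer.StringDecoder
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.streaming.kafka.{KafkaUtils, OffsetRange}

val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("kafka-0-8-rdd-example"))

// One offset range per topic/partition to read; here offsets [0, 100) of partition 0 of "events".
val offsetRanges = Array(OffsetRange("events", 0, 0L, 100L))
val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")

val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
  sc, kafkaParams, offsetRanges)

rdd.map { case (_, value) => value }.take(10).foreach(println)
```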
* @@ -303,7 +303,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. + * Create an RDD from Kafka using offset ranges for each topic and partition. * * @param jsc JavaSparkContext object * @param kafkaParams Kafka @@ -340,7 +340,7 @@ object KafkaUtils { } /** - * Create a RDD from Kafka using offset ranges for each topic and partition. This allows you + * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you * specify the Kafka leader to connect to (to optimize fetching) and access the message as well * as the metadata. * diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index ab1c5055a253f..8a747a5e29b58 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -184,6 +184,7 @@ class DirectKafkaStreamSuite collectedData.contains("b") } assert(!collectedData.contains("a")) + ssc.stop() } @@ -230,6 +231,7 @@ class DirectKafkaStreamSuite collectedData.contains("b") } assert(!collectedData.contains("a")) + ssc.stop() } // Test to verify the offset ranges can be recovered from the checkpoints diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 6a35ac14a8f6f..426cd83b4ddf8 100644 --- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -80,5 +80,6 @@ class KafkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { assert(result.synchronized { sent === result }) } + ssc.stop() } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 80e0cce055862..a0ccd086d90fa 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -27,7 +27,6 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.internal.Logging -import org.apache.spark.streaming.Duration /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. @@ -102,27 +101,32 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) */ - override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + override def shutdown( + checkpointer: IRecordProcessorCheckpointer, + reason: ShutdownReason): Unit = { logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") - reason match { - /* - * TERMINATE Use Case. Checkpoint. 
- * Checkpoint to indicate that all records from the shard have been drained and processed. - * It's now OK to read from the new shards that resulted from a resharding event. - */ - case ShutdownReason.TERMINATE => - receiver.removeCheckpointer(shardId, checkpointer) + // null if not initialized before shutdown: + if (shardId == null) { + logWarning(s"No shardId for workerId $workerId?") + } else { + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => receiver.removeCheckpointer(shardId, checkpointer) - /* - * ZOMBIE Use Case or Unknown reason. NoOp. - * No checkpoint because other workers may have taken over and already started processing - * the same records. - * This may lead to records being processed more than once. - */ - case _ => - receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint + /* + * ZOMBIE Use Case or Unknown reason. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + * Return null so that we don't checkpoint + */ + case _ => receiver.removeCheckpointer(shardId, null) + } } - } } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index a0007d33d6257..b2daffa34ccbf 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -33,10 +33,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -57,6 +53,10 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream[T: ClassTag]( ssc: StreamingContext, @@ -81,10 +81,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -107,6 +103,9 @@ object KinesisUtils { * Kinesis `Record`, which contains both message data, and metadata. 
* @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T: ClassTag]( @@ -134,10 +133,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -156,6 +151,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream( ssc: StreamingContext, @@ -178,10 +177,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param ssc StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -202,6 +197,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ def createStream( ssc: StreamingContext, @@ -225,10 +223,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -250,6 +244,10 @@ object KinesisUtils { * @param messageHandler A custom message handler that can generate a generic output from a * Kinesis `Record`, which contains both message data, and metadata. * @param recordClass Class of the records in DStream + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. 
*/ def createStream[T]( jssc: JavaStreamingContext, @@ -272,10 +270,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -299,6 +293,9 @@ object KinesisUtils { * @param recordClass Class of the records in DStream * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. */ // scalastyle:off def createStream[T]( @@ -326,10 +323,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain - * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain - * gets the AWS credentials. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -348,6 +341,10 @@ object KinesisUtils { * details on the different types of checkpoints. * @param storageLevel Storage level to use for storing the received objects. * StorageLevel.MEMORY_AND_DISK_2 is recommended. + * + * @note The AWS credentials will be discovered using the DefaultAWSCredentialsProviderChain + * on the workers. See AWS documentation to understand how DefaultAWSCredentialsProviderChain + * gets the AWS credentials. */ def createStream( jssc: JavaStreamingContext, @@ -367,10 +364,6 @@ object KinesisUtils { * Create an input stream that pulls messages from a Kinesis stream. * This uses the Kinesis Client Library (KCL) to pull messages from Kinesis. * - * Note: - * The given AWS credentials will get saved in DStream checkpoints if checkpointing - * is enabled. Make sure that your checkpoint directory is secure. - * * @param jssc Java StreamingContext object * @param kinesisAppName Kinesis application name used by the Kinesis Client Library * (KCL) to update DynamoDB @@ -391,6 +384,9 @@ object KinesisUtils { * StorageLevel.MEMORY_AND_DISK_2 is recommended. * @param awsAccessKeyId AWS AccessKeyId (if null, will use DefaultAWSCredentialsProviderChain) * @param awsSecretKey AWS SecretKey (if null, will use DefaultAWSCredentialsProviderChain) + * + * @note The given AWS credentials will get saved in DStream checkpoints if checkpointing + * is enabled. Make sure that your checkpoint directory is secure. 
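To make the relocated credential notes concrete, here is a minimal, hypothetical sketch of the simplest Scala `createStream` overload; the application name, stream name, endpoint URL, and region are placeholders:

```scala
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.KinesisUtils

val ssc = new StreamingContext(
  new SparkConf().setMaster("local[2]").setAppName("kinesis-example"), Seconds(10))

// No keys are passed, so the DefaultAWSCredentialsProviderChain is used on the workers.
val byteStream = KinesisUtils.createStream(
  ssc,
  "example-kcl-app",                          // KCL application (and DynamoDB table) name
  "example-stream",                           // Kinesis stream name
  "https://kinesis.us-east-1.amazonaws.com",  // endpoint URL
  "us-east-1",                                // region name
  InitialPositionInStream.LATEST,
  Seconds(10),                                // checkpoint interval
  StorageLevel.MEMORY_AND_DISK_2)

byteStream
  .map(bytes => new String(bytes, java.nio.charset.StandardCharsets.UTF_8))
  .print()
ssc.start()
ssc.awaitTermination()
```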
*/ def createStream( jssc: JavaStreamingContext, diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 905c33834df16..a4d81a680979e 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -221,7 +221,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) assert(collectedData.toSet === testData.toSet) // Verify that the block fetching is skipped when isBlockValid is set to false. - // This is done by using a RDD whose data is only in memory but is set to skip block fetching + // This is done by using an RDD whose data is only in memory but is set to skip block fetching // Using that RDD will throw exception, as it skips block fetching even if the blocks are in // in BlockManager. if (testIsBlockValid) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index e18831382d4d5..3810110099993 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -42,7 +42,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges - /** Return a RDD that brings edges together with their source and destination vertices. */ + /** Return an RDD that brings edges together with their source and destination vertices. */ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { replicatedVertexView.upgrade(vertices, true, true) replicatedVertexView.edges.partitionsRDD.mapPartitions(_.flatMap { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index f4b00757a8b54..f926984aa6335 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -58,7 +58,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors} * `alpha` is the random reset probability (typically 0.15), `inNbrs[i]` is the set of * neighbors which link to `i` and `outDeg[j]` is the out degree of vertex `j`. * - * Note that this is not the "normalized" PageRank and as a consequence pages that have no + * @note This is not the "normalized" PageRank and as a consequence pages that have no * inlinks will have a PageRank of alpha. 
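A tiny, hypothetical GraphX sketch illustrating the un-normalized behaviour described in the note; the graph data and parameters are invented:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.graphx.lib.PageRank

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("pagerank-example"))

// Vertices 1, 2, 3 form a cycle; vertex 4 links out but has no inlinks.
val edges = sc.parallelize(Seq(
  Edge(1L, 2L, 1.0), Edge(2L, 3L, 1.0), Edge(3L, 1L, 1.0), Edge(4L, 1L, 1.0)))
val graph = Graph.fromEdges(edges, defaultValue = 1.0)

// Static PageRank; with no normalization, vertex 4 ends up at roughly alpha = 0.15.
val ranks = PageRank.run(graph, numIter = 20, resetProb = 0.15).vertices
ranks.collect().sortBy(_._1).foreach { case (id, rank) => println(s"$id -> $rank") }
```

The `require` checks added to `runParallelPersonalizedPageRank` in the hunk that follows reject a non-positive `numIter`, a reset probability outside [0, 1], and an empty `sources` array up front rather than failing later.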
*/ object PageRank extends Logging { @@ -185,6 +185,13 @@ object PageRank extends Logging { def runParallelPersonalizedPageRank[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], numIter: Int, resetProb: Double = 0.15, sources: Array[VertexId]): Graph[Vector, Double] = { + require(numIter > 0, s"Number of iterations must be greater than 0," + + s" but got ${numIter}") + require(resetProb >= 0 && resetProb <= 1, s"Random reset probability must belong" + + s" to [0, 1], but got ${resetProb}") + require(sources.nonEmpty, s"The list of sources must be non-empty," + + s" but got ${sources.mkString("[", ",", "]")}") + // TODO if one sources vertex id is outside of the int range // we won't be able to store its activations in a sparse vector val zero = Vectors.sparse(sources.size, List()).asBreeze diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 62a22008d0d5d..250b2a882feb5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -357,7 +357,7 @@ static int javaMajorVersion(String javaVersion) { static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { // TODO: change to the correct directory once the assembly build is changed. File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { + if (new File(sparkHome, "jars").isDirectory()) { libdir = new File(sparkHome, "jars"); checkState(!failIfNotFound || libdir.isDirectory(), "Library directory '%s' does not exist.", diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 7d6693b4cdf5b..792ade8f0bdbd 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -25,7 +25,7 @@ import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.internal.Logging import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.{CommandLineUtils, ShutdownHookManager, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -92,8 +92,11 @@ private[mesos] class MesosClusterDispatcher( } } -private[mesos] object MesosClusterDispatcher extends Logging { - def main(args: Array[String]) { +private[mesos] object MesosClusterDispatcher + extends Logging + with CommandLineUtils { + + override def main(args: Array[String]) { Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) diff --git a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index 11e13441eeba6..ef08502ec8dd6 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -18,23 +18,43 @@ package org.apache.spark.deploy.mesos import scala.annotation.tailrec +import scala.collection.mutable -import org.apache.spark.SparkConf import org.apache.spark.util.{IntParam, Utils} - +import 
org.apache.spark.SparkConf private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: SparkConf) { - var host = Utils.localHostName() - var port = 7077 - var name = "Spark Cluster" - var webUiPort = 8081 + var host: String = Utils.localHostName() + var port: Int = 7077 + var name: String = "Spark Cluster" + var webUiPort: Int = 8081 + var verbose: Boolean = false var masterUrl: String = _ var zookeeperUrl: Option[String] = None var propertiesFile: String = _ + val confProperties: mutable.HashMap[String, String] = + new mutable.HashMap[String, String]() parse(args.toList) + // scalastyle:on println propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + Utils.updateSparkConfigFromProperties(conf, confProperties) + + // scalastyle:off println + if (verbose) { + MesosClusterDispatcher.printStream.println(s"Using host: $host") + MesosClusterDispatcher.printStream.println(s"Using port: $port") + MesosClusterDispatcher.printStream.println(s"Using webUiPort: $webUiPort") + MesosClusterDispatcher.printStream.println(s"Framework Name: $name") + + Option(propertiesFile).foreach { file => + MesosClusterDispatcher.printStream.println(s"Using properties file: $file") + } + + MesosClusterDispatcher.printStream.println(s"Spark Config properties set:") + conf.getAll.foreach(println) + } @tailrec private def parse(args: List[String]): Unit = args match { @@ -58,9 +78,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--master" | "-m") :: value :: tail => if (!value.startsWith("mesos://")) { // scalastyle:off println - System.err.println("Cluster dispatcher only supports mesos (uri begins with mesos://)") + MesosClusterDispatcher.printStream + .println("Cluster dispatcher only supports mesos (uri begins with mesos://)") // scalastyle:on println - System.exit(1) + MesosClusterDispatcher.exitFn(1) } masterUrl = value.stripPrefix("mesos://") parse(tail) @@ -73,28 +94,45 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: propertiesFile = value parse(tail) + case ("--conf") :: value :: tail => + val pair = MesosClusterDispatcher. 
+ parseSparkConfProperty(value) + confProperties(pair._1) = pair._2 + parse(tail) + case ("--help") :: tail => - printUsageAndExit(0) + printUsageAndExit(0) + + case ("--verbose") :: tail => + verbose = true + parse(tail) case Nil => - if (masterUrl == null) { + if (Option(masterUrl).isEmpty) { // scalastyle:off println - System.err.println("--master is required") + MesosClusterDispatcher.printStream.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - case _ => + case value => + // scalastyle:off println + MesosClusterDispatcher.printStream.println(s"Unrecognized option: '${value.head}'") + // scalastyle:on println printUsageAndExit(1) } private def printUsageAndExit(exitCode: Int): Unit = { + val outStream = MesosClusterDispatcher.printStream + // scalastyle:off println - System.err.println( + outStream.println( "Usage: MesosClusterDispatcher [options]\n" + "\n" + "Options:\n" + " -h HOST, --host HOST Hostname to listen on\n" + + " --help Show this help message and exit.\n" + + " --verbose, Print additional debug output.\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + " --webui-port WEBUI_PORT WebUI Port to listen on (default: 8081)\n" + " --name NAME Framework name to show in Mesos UI\n" + @@ -102,8 +140,10 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: " -z --zk ZOOKEEPER Comma delimited URLs for connecting to \n" + " Zookeeper for persistence\n" + " --properties-file FILE Path to a custom Spark properties file.\n" + - " Default is conf/spark-defaults.conf.") + " Default is conf/spark-defaults.conf \n" + + " --conf PROP=VALUE Arbitrary Spark configuration property.\n" + + " Takes precedence over defined properties in properties-file.") // scalastyle:on println - System.exit(exitCode) + MesosClusterDispatcher.exitFn(exitCode) } } diff --git a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 3b96488a129a9..ff60b88c6d533 100644 --- a/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/mesos/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.rest.mesos import java.io.File import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse @@ -62,11 +62,10 @@ private[mesos] class MesosSubmitRequestServlet( private val DEFAULT_CORES = 1.0 private val nextDriverNumber = new AtomicLong(0) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private def newDriverId(submitDate: Date): String = { - "driver-%s-%04d".format( - createDateFormat.format(submitDate), nextDriverNumber.incrementAndGet()) - } + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + private def newDriverId(submitDate: Date): String = + f"driver-${createDateFormat.format(submitDate)}-${nextDriverNumber.incrementAndGet()}%04d" /** * Build a driver description from the fields specified in the submit request. 
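A hypothetical sketch of how the new `--conf` and `--verbose` dispatcher flags are consumed. `MesosClusterDispatcherArguments` is `private[mesos]`, so this only compiles from inside that package (as the new test suite further below does); the docker image value is a placeholder:

```scala
package org.apache.spark.deploy.mesos

import org.apache.spark.SparkConf

object DispatcherArgsExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val parsed = new MesosClusterDispatcherArguments(Array(
      "--master", "mesos://dispatcher-host:5050",
      "--conf", "spark.mesos.executor.docker.image=example/spark:2.1.0",
      "--verbose"), conf)

    // --conf pairs are collected into confProperties; spark.* keys are then merged into the SparkConf.
    assert(parsed.verbose)
    assert(conf.get("spark.mesos.executor.docker.image") == "example/spark:2.1.0")
  }
}
```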
diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 8db1d126d59b8..f384290a6ff25 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -531,13 +531,7 @@ private[spark] class MesosClusterScheduler( .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) - - desc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil.setupContainerBuilderDockerInfo(image, - desc.conf, - taskInfo.getContainerBuilder) - } - + taskInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf)) taskInfo.build } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 5063c1fe988bc..3258b09c06278 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -158,7 +158,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), None, None, sc.conf.getOption("spark.mesos.driver.frameworkId") @@ -213,7 +213,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .format(prefixEnv, runScript) + s" --driver-url $driverURL" + s" --executor-id $taskId" + - s" --hostname ${offer.getHostname}" + + s" --hostname ${executorHostname(offer)}" + s" --cores $numCores" + s" --app-id $appId") } else { @@ -225,7 +225,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" + s" --driver-url $driverURL" + s" --executor-id $taskId" + - s" --hostname ${offer.getHostname}" + + s" --hostname ${executorHostname(offer)}" + s" --cores $numCores" + s" --app-id $appId") command.addUris(CommandInfo.URI.newBuilder().setValue(uri.get).setCache(useFetcherCache)) @@ -418,16 +418,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( .setSlaveId(offer.getSlaveId) .setCommand(createCommand(offer, taskCPUs + extraCoresPerExecutor, taskId)) .setName("Task " + taskId) - taskBuilder.addAllResources(resourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil.setupContainerBuilderDockerInfo( - image, - sc.conf, - taskBuilder.getContainerBuilder - ) - } + taskBuilder.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) tasks(offer.getId) ::= taskBuilder.build() remainingResources(offerId) = resourcesLeft.asJava @@ -658,6 +650,15 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private def numExecutors(): Int = { slaves.values.map(_.taskIDs.size).sum } + + private def executorHostname(offer: Offer): String = { + if (sc.conf.getOption("spark.mesos.network.name").isDefined) { + // The agent's IP is not visible in a CNI container, so we bind to 0.0.0.0 + "0.0.0.0" + } else { + offer.getHostname + } + } } private class Slave(val hostname: 
String) { diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 09a252f3c74ac..779ffb52299cc 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -77,7 +77,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( sc.sparkUser, sc.appName, sc.conf, - sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.appUIAddress)), + sc.conf.getOption("spark.mesos.driver.webui.url").orElse(sc.ui.map(_.webUrl)), Option.empty, Option.empty, sc.conf.getOption("spark.mesos.driver.frameworkId") @@ -155,14 +155,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setCommand(command) .setData(ByteString.copyFrom(createExecArg())) - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil.setupContainerBuilderDockerInfo( - image, - sc.conf, - executorInfo.getContainerBuilder() - ) - } - + executorInfo.setContainer(MesosSchedulerBackendUtil.containerInfo(sc.conf)) (executorInfo.build(), resourcesAfterMem.asJava) } diff --git a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 3fe06743b8809..a2adb228dc299 100644 --- a/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -17,8 +17,8 @@ package org.apache.spark.scheduler.cluster.mesos -import org.apache.mesos.Protos.{ContainerInfo, Image, Volume} -import org.apache.mesos.Protos.ContainerInfo.DockerInfo +import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Volume} +import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo} import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging @@ -99,67 +99,67 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { .toList } - /** - * Construct a DockerInfo structure and insert it into a ContainerInfo - */ - def addDockerInfo( - container: ContainerInfo.Builder, - image: String, - containerizer: String, - forcePullImage: Boolean = false, - volumes: Option[List[Volume]] = None, - portmaps: Option[List[ContainerInfo.DockerInfo.PortMapping]] = None): Unit = { - - containerizer match { - case "docker" => - container.setType(ContainerInfo.Type.DOCKER) - val docker = ContainerInfo.DockerInfo.newBuilder() - .setImage(image) - .setForcePullImage(forcePullImage) - // TODO (mgummelt): Remove this. Portmaps have no effect, - // as we don't support bridge networking. 
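For reference, these are the configuration keys consulted by the consolidated `containerInfo` builder introduced in the hunk that follows; every value below is an invented example, and the volume/port-mapping strings follow the formats the existing parsers expect:

```scala
import org.apache.spark.SparkConf

// Keys read by MesosSchedulerBackendUtil.containerInfo; all values are placeholders.
val conf = new SparkConf()
  .set("spark.mesos.executor.docker.image", "example/spark-executor:2.1.0")
  .set("spark.mesos.executor.docker.forcePullImage", "true")
  .set("spark.mesos.containerizer", "mesos")  // "docker" (the default) or "mesos"
  .set("spark.mesos.executor.docker.volumes", "/host/logs:/container/logs:rw")
  .set("spark.mesos.executor.docker.portmaps", "8080:80:tcp")
  .set("spark.mesos.network.name", "example-cni-network")  // also makes executors bind to 0.0.0.0
```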
- portmaps.foreach(_.foreach(docker.addPortMappings)) - container.setDocker(docker) - case "mesos" => - container.setType(ContainerInfo.Type.MESOS) - val imageProto = Image.newBuilder() - .setType(Image.Type.DOCKER) - .setDocker(Image.Docker.newBuilder().setName(image)) - .setCached(!forcePullImage) - container.setMesos(ContainerInfo.MesosInfo.newBuilder().setImage(imageProto)) - case _ => - throw new SparkException( - "spark.mesos.containerizer must be one of {\"docker\", \"mesos\"}") + def containerInfo(conf: SparkConf): ContainerInfo = { + val containerType = if (conf.contains("spark.mesos.executor.docker.image") && + conf.get("spark.mesos.containerizer", "docker") == "docker") { + ContainerInfo.Type.DOCKER + } else { + ContainerInfo.Type.MESOS } - volumes.foreach(_.foreach(container.addVolumes)) + val containerInfo = ContainerInfo.newBuilder() + .setType(containerType) + + conf.getOption("spark.mesos.executor.docker.image").map { image => + val forcePullImage = conf + .getOption("spark.mesos.executor.docker.forcePullImage") + .exists(_.equals("true")) + + val portMaps = conf + .getOption("spark.mesos.executor.docker.portmaps") + .map(parsePortMappingsSpec) + .getOrElse(List.empty) + + if (containerType == ContainerInfo.Type.DOCKER) { + containerInfo.setDocker(dockerInfo(image, forcePullImage, portMaps)) + } else { + containerInfo.setMesos(mesosInfo(image, forcePullImage)) + } + + val volumes = conf + .getOption("spark.mesos.executor.docker.volumes") + .map(parseVolumesSpec) + + volumes.foreach(_.foreach(containerInfo.addVolumes(_))) + } + + conf.getOption("spark.mesos.network.name").map { name => + val info = NetworkInfo.newBuilder().setName(name).build() + containerInfo.addNetworkInfos(info) + } + + containerInfo.build() } - /** - * Setup a docker containerizer from MesosDriverDescription scheduler properties - */ - def setupContainerBuilderDockerInfo( - imageName: String, - conf: SparkConf, - builder: ContainerInfo.Builder): Unit = { - val forcePullImage = conf - .getOption("spark.mesos.executor.docker.forcePullImage") - .exists(_.equals("true")) - val volumes = conf - .getOption("spark.mesos.executor.docker.volumes") - .map(parseVolumesSpec) - val portmaps = conf - .getOption("spark.mesos.executor.docker.portmaps") - .map(parsePortMappingsSpec) - - val containerizer = conf.get("spark.mesos.containerizer", "docker") - addDockerInfo( - builder, - imageName, - containerizer, - forcePullImage = forcePullImage, - volumes = volumes, - portmaps = portmaps) - logDebug("setupContainerDockerInfo: using docker image: " + imageName) + private def dockerInfo( + image: String, + forcePullImage: Boolean, + portMaps: List[ContainerInfo.DockerInfo.PortMapping]): DockerInfo = { + val dockerBuilder = ContainerInfo.DockerInfo.newBuilder() + .setImage(image) + .setForcePullImage(forcePullImage) + portMaps.foreach(dockerBuilder.addPortMappings(_)) + + dockerBuilder.build + } + + private def mesosInfo(image: String, forcePullImage: Boolean): MesosInfo = { + val imageProto = Image.newBuilder() + .setType(Image.Type.DOCKER) + .setDocker(Image.Docker.newBuilder().setName(image)) + .setCached(!forcePullImage) + ContainerInfo.MesosInfo.newBuilder() + .setImage(imageProto) + .build } } diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala new file mode 100644 index 0000000000000..33e7d69d53d38 --- /dev/null +++ 
b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArgumentsSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherArgumentsSuite extends SparkFunSuite + with TestPrematureExit { + + test("test if spark config args are passed successfully") { + val args = Array[String]("--master", "mesos://localhost:5050", "--conf", "key1=value1", + "--conf", "spark.mesos.key2=value2", "--verbose") + val conf = new SparkConf() + new MesosClusterDispatcherArguments(args, conf) + + assert(conf.getOption("key1").isEmpty) + assert(conf.get("spark.mesos.key2") == "value2") + } + + test("test non conf settings") { + val masterUrl = "mesos://localhost:5050" + val port = "1212" + val zookeeperUrl = "zk://localhost:2181" + val host = "localhost" + val webUiPort = "2323" + val name = "myFramework" + + val args1 = Array("--master", masterUrl, "--verbose", "--name", name) + val args2 = Array("-p", port, "-h", host, "-z", zookeeperUrl) + val args3 = Array("--webui-port", webUiPort) + + val args = args1 ++ args2 ++ args3 + val conf = new SparkConf() + val mesosDispClusterArgs = new MesosClusterDispatcherArguments(args, conf) + + assert(mesosDispClusterArgs.verbose) + assert(mesosDispClusterArgs.confProperties.isEmpty) + assert(mesosDispClusterArgs.host == host) + assert(Option(mesosDispClusterArgs.masterUrl).isDefined) + assert(mesosDispClusterArgs.masterUrl == masterUrl.stripPrefix("mesos://")) + assert(Option(mesosDispClusterArgs.zookeeperUrl).isDefined) + assert(mesosDispClusterArgs.zookeeperUrl == Some(zookeeperUrl)) + assert(mesosDispClusterArgs.name == name) + assert(mesosDispClusterArgs.webUiPort == webUiPort.toInt) + assert(mesosDispClusterArgs.port == port.toInt) + } +} diff --git a/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala new file mode 100644 index 0000000000000..7484e3b83670d --- /dev/null +++ b/mesos/src/test/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.mesos + +import org.apache.spark.SparkFunSuite +import org.apache.spark.deploy.TestPrematureExit + +class MesosClusterDispatcherSuite extends SparkFunSuite + with TestPrematureExit{ + + test("prints usage on empty input") { + testPrematureExit(Array[String](), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints usage with only --help") { + testPrematureExit(Array("--help"), + "Usage: MesosClusterDispatcher", MesosClusterDispatcher) + } + + test("prints error with unrecognized options") { + testPrematureExit(Array("--blarg"), "Unrecognized option: '--blarg'", MesosClusterDispatcher) + testPrematureExit(Array("-bleg"), "Unrecognized option: '-bleg'", MesosClusterDispatcher) + } +} diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index 87d9080de569e..74e5ce227d8b4 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -210,4 +210,30 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi (v.getName, v.getValue)).toMap assert(env.getOrElse("TEST_ENV", null) == "TEST_VAL") } + + test("supports spark.mesos.network.name") { + setScheduler() + + val mem = 1000 + val cpu = 1 + + val response = scheduler.submitDriver( + new MesosDriverDescription("d1", "jar", mem, cpu, true, + command, + Map("spark.mesos.executor.home" -> "test", + "spark.app.name" -> "test", + "spark.mesos.network.name" -> "test-network-name"), + "s1", + new Date())) + + assert(response.success) + + val offer = Utils.createOffer("o1", "s1", mem, cpu) + scheduler.resourceOffers(driver, List(offer).asJava) + + val launchedTasks = Utils.verifyTaskLaunched(driver, "o1") + val networkInfos = launchedTasks.head.getContainer.getNetworkInfosList + assert(networkInfos.size == 1) + assert(networkInfos.get(0).getName == "test-network-name") + } } diff --git a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f73638fda6232..a674da4066456 100644 --- a/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -388,9 +388,6 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite val dockerInfo = containerInfo.getDocker - assert(dockerInfo.getImage == "some_image") - assert(dockerInfo.getForcePullImage) - val portMappings = dockerInfo.getPortMappingsList.asScala assert(portMappings.size == 1) @@ -491,6 +488,22 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(!uris.asScala.head.getCache) } + test("mesos supports spark.mesos.network.name") { + 
setBackend(Map( + "spark.mesos.network.name" -> "test-network-name" + )) + + val (mem, cpu) = (backend.executorMemory(sc), 4) + + val offer1 = createOffer("o1", "s1", mem, cpu) + backend.resourceOffers(driver, List(offer1).asJava) + + val launchedTasks = verifyTaskLaunched(driver, "o1") + val networkInfos = launchedTasks.head.getContainer.getNetworkInfosList + assert(networkInfos.size == 1) + assert(networkInfos.get(0).getName == "test-network-name") + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def verifyDeclinedOffer(driver: SchedulerDriver, diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 2e4a58dc6291c..22e4ec693b1f7 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -30,7 +30,7 @@ import org.apache.spark.annotation.Since /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @Since("2.0.0") sealed trait Vector extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 252acc156583f..c581fed177273 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.ParamMap abstract class Model[M <: Model[M]] extends Transformer { /** * The parent estimator that produced this model. - * Note: For ensembles' component Models, this value can be null. + * @note For ensembles' component Models, this value can be null. */ @transient var parent: Estimator[M] = _ diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 195a93e086725..f406f8c426d0c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") ( override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) - new Pipeline().setStages(newStages) + new Pipeline(uid).setStages(newStages) } @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e29d7f48a1d6b..aa92edde7acd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -58,7 +58,8 @@ private[ml] trait PredictorParams extends Params /** * :: DeveloperApi :: - * Abstraction for prediction problems (regression and classification). + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in [[fit()]]. * * @tparam FeaturesType Type of features. * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. @@ -87,7 +88,12 @@ abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) - copyValues(train(dataset).setParent(this)) + + // Cast LabelCol to DoubleType and keep the metadata. 
+ val labelMeta = dataset.schema($(labelCol)).metadata + val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta) + + copyValues(train(casted).setParent(this)) } override def copy(extra: ParamMap): Learner @@ -121,7 +127,7 @@ abstract class Predictor[ * and put it in an RDD with strong types. */ protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index d1b21b16f2342..a3da3067e1b5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -71,7 +71,7 @@ abstract class Classifier[ * and put it in an RDD with strong types. * * @param dataset DataFrame with columns for labels ([[org.apache.spark.sql.types.NumericType]]) - * and features ([[Vector]]). Labels are cast to [[DoubleType]]. + * and features ([[Vector]]). * @param numClasses Number of classes label can take. Labels must be integers in the range * [0, numClasses). * @throws SparkException if any label is not an integer >= 0 @@ -79,7 +79,7 @@ abstract class Classifier[ protected def extractLabeledPoints(dataset: Dataset[_], numClasses: Int): RDD[LabeledPoint] = { require(numClasses > 0, s"Classifier (in extractLabeledPoints) found numClasses =" + s" $numClasses, but requires numClasses > 0.") - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label % 1 == 0 && label >= 0 && label < numClasses, s"Classifier was given" + s" dataset with invalid label $label. Labels must be integers in range" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index bb192ab5f25ab..7424031ed4608 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -207,9 +207,9 @@ class DecisionTreeClassificationModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestClassifier]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestClassifier]] + * to determine feature importance instead. 
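The `Predictor.fit()` change above casts any `NumericType` label column to `DoubleType` while keeping the column's metadata, which is why the various `extractLabeledPoints`-style selects in this patch no longer cast the label column themselves. A minimal sketch of the same pattern using only public API (the metadata-preserving `withColumn` overload used inside `fit()` is private to Spark); the column names and the nominal-attribute metadata are illustrative:

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

object LabelCastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("label-cast-sketch").getOrCreate()
    import spark.implicits._

    // An integer label column carrying nominal-attribute metadata, as StringIndexer might produce.
    val meta = NominalAttribute.defaultAttr.withName("label").withValues("no", "yes").toMetadata()
    val df = Seq((0, 0.5), (1, 1.5)).toDF("label", "feature")
      .select(col("label").as("label", meta), col("feature"))

    // Cast to DoubleType while keeping the metadata, mirroring what fit() now does internally.
    val labelMeta = df.schema("label").metadata
    val casted = df.select(col("label").cast(DoubleType).as("label", labelMeta), col("feature"))

    println(casted.schema("label").dataType) // DoubleType
    println(casted.schema("label").metadata) // nominal-attribute metadata is preserved
    spark.stop()
  }
}
```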
*/ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 8bffe0cda0327..52f93f5a6b345 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -43,7 +43,6 @@ import org.apache.spark.sql.types.DoubleType * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * learning algorithm for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * The implementation is based upon: J.H. Friedman. "Stochastic Gradient Boosting." 1999. * @@ -54,6 +53,8 @@ import org.apache.spark.sql.types.DoubleType * based on the loss function, whereas the original gradient boosting method does not. * - We expect to implement TreeBoost in the future: * [https://issues.apache.org/jira/browse/SPARK-4240] + * + * @note Multiclass labels are not currently supported. */ @Since("1.4.0") class GBTClassifier @Since("1.4.0") ( @@ -128,7 +129,7 @@ class GBTClassifier @Since("1.4.0") ( // We copy and modify this from Classifier.extractLabeledPoints since GBT only supports // 2 classes now. This lets us provide a more precise error message. val oldDataset: RDD[LabeledPoint] = - dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => require(label == 0 || label == 1, s"GBTClassifier was given" + s" dataset with invalid label $label. Labels must be in {0,1}; note that" + @@ -169,10 +170,11 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { * Gradient-Boosted Trees (GBTs) (http://en.wikipedia.org/wiki/Gradient_boosting) * model for classification. * It supports binary labels, as well as both continuous and categorical features. - * Note: Multiclass labels are not currently supported. * * @param _trees Decision trees in the ensemble. * @param _treeWeights Weights for the decision trees in the ensemble. + * + * @note Multiclass labels are not currently supported. 
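`GBTClassifier` still supports only binary labels, so the `require` in `train()` above rejects anything outside {0, 1} with a precise message. A short end-to-end sketch on toy data (the dataset and column names are made up for illustration):

```scala
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object GbtBinarySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("gbt-binary-sketch").getOrCreate()
    import spark.implicits._

    // Toy binary-label data; a label such as 2.0 would fail the require() shown above.
    val train = Seq(
      (0.0, Vectors.dense(0.0, 1.1)),
      (1.0, Vectors.dense(2.0, 1.0)),
      (0.0, Vectors.dense(0.1, 1.2)),
      (1.0, Vectors.dense(2.2, 0.9))
    ).toDF("label", "features")

    val model = new GBTClassifier().setMaxIter(5).fit(train)
    model.transform(train).select("label", "prediction").show()
    spark.stop()
  }
}
```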
*/ @Since("1.6.0") class GBTClassificationModel private[ml]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8fdaae04c42ec..d07b4adebb08f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -322,7 +322,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -438,18 +438,14 @@ class LogisticRegression @Since("1.2.0") ( val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { // Remove the L1 penalization on the intercept - val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0) + val isIntercept = $(fitIntercept) && index >= numFeatures * numCoefficientSets if (isIntercept) { 0.0 } else { if (standardizationParam) { regParamL1 } else { - val featureIndex = if ($(fitIntercept)) { - index % numFeaturesPlusIntercept - } else { - index % numFeatures - } + val featureIndex = index / numCoefficientSets // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to // perform this reverse standardization by penalizing each component @@ -466,8 +462,12 @@ class LogisticRegression @Since("1.2.0") ( new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol)) } - val initialCoefficientsWithIntercept = - Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept) + /* + The coefficients are laid out in column major order during training. Here we initialize + a column major matrix of initial coefficients. 
+ */ + val initialCoefWithInterceptMatrix = + Matrices.zeros(numCoefficientSets, numFeaturesPlusIntercept) val initialModelIsValid = optInitialModel match { case Some(_initialModel) => @@ -486,17 +486,15 @@ class LogisticRegression @Since("1.2.0") ( } if (initialModelIsValid) { - val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray val providedCoef = optInitialModel.get.coefficientMatrix - providedCoef.foreachActive { (row, col, value) => - val flatIndex = row * numFeaturesPlusIntercept + col + providedCoef.foreachActive { (classIndex, featureIndex, value) => // We need to scale the coefficients since they will be trained in the scaled space - initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col) + initialCoefWithInterceptMatrix.update(classIndex, featureIndex, + value * featuresStd(featureIndex)) } if ($(fitIntercept)) { - optInitialModel.get.interceptVector.foreachActive { (index, value) => - val coefIndex = (index + 1) * numFeaturesPlusIntercept - 1 - initialCoefWithInterceptArray(coefIndex) = value + optInitialModel.get.interceptVector.foreachActive { (classIndex, value) => + initialCoefWithInterceptMatrix.update(classIndex, numFeatures, value) } } } else if ($(fitIntercept) && isMultinomial) { @@ -526,8 +524,7 @@ class LogisticRegression @Since("1.2.0") ( val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing val rawMean = rawIntercepts.sum / rawIntercepts.length rawIntercepts.indices.foreach { i => - initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) = - rawIntercepts(i) - rawMean + initialCoefWithInterceptMatrix.update(i, numFeatures, rawIntercepts(i) - rawMean) } } else if ($(fitIntercept)) { /* @@ -543,12 +540,12 @@ class LogisticRegression @Since("1.2.0") ( b = \log{P(1) / P(0)} = \log{count_1 / count_0} }}} */ - initialCoefficientsWithIntercept.toArray(numFeatures) = math.log( - histogram(1) / histogram(0)) + initialCoefWithInterceptMatrix.update(0, numFeatures, + math.log(histogram(1) / histogram(0))) } val states = optimizer.iterations(new CachedDiffFunction(costFun), - initialCoefficientsWithIntercept.asBreeze.toDenseVector) + new BDV[Double](initialCoefWithInterceptMatrix.toArray)) /* Note that in Logistic Regression, the objective history (loss + regularization) @@ -572,19 +569,32 @@ class LogisticRegression @Since("1.2.0") ( /* The coefficients are trained in the scaled space; we're converting them back to the original space. + + Additionally, since the coefficients were laid out in column major order during training + to avoid extra computation, we convert them back to row major before passing them to the + model. + Note that the intercept in scaled space and original space is the same; as a result, no scaling is needed. */ - val rawCoefficients = state.x.toArray.clone() - val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i => - // flatIndex will loop though rawCoefficients, and skip the intercept terms. 
- val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i - val featureIndex = i % numFeatures - if (featuresStd(featureIndex) != 0.0) { - rawCoefficients(flatIndex) / featuresStd(featureIndex) - } else { - 0.0 + val allCoefficients = state.x.toArray.clone() + val allCoefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, + allCoefficients) + val denseCoefficientMatrix = new DenseMatrix(numCoefficientSets, numFeatures, + new Array[Double](numCoefficientSets * numFeatures), isTransposed = true) + val interceptVec = if ($(fitIntercept) || !isMultinomial) { + Vectors.zeros(numCoefficientSets) + } else { + Vectors.sparse(numCoefficientSets, Seq()) + } + // separate intercepts and coefficients from the combined matrix + allCoefMatrix.foreachActive { (classIndex, featureIndex, value) => + val isIntercept = $(fitIntercept) && (featureIndex == numFeatures) + if (!isIntercept && featuresStd(featureIndex) != 0.0) { + denseCoefficientMatrix.update(classIndex, featureIndex, + value / featuresStd(featureIndex)) } + if (isIntercept) interceptVec.toArray(classIndex) = value } if ($(regParam) == 0.0 && isMultinomial) { @@ -597,17 +607,16 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val coefficientMean = coefficientArray.sum / coefficientArray.length - coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean} + val denseValues = denseCoefficientMatrix.values + val coefficientMean = denseValues.sum / denseValues.length + denseCoefficientMatrix.update(_ - coefficientMean) } - val denseCoefficientMatrix = - new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true) // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471 val compressedCoefficientMatrix = if (isMultinomial) { denseCoefficientMatrix } else { - val compressedVector = Vectors.dense(coefficientArray).compressed + val compressedVector = Vectors.dense(denseCoefficientMatrix.values).compressed compressedVector match { case dv: DenseVector => denseCoefficientMatrix case sv: SparseVector => @@ -616,25 +625,13 @@ class LogisticRegression @Since("1.2.0") ( } } - val interceptsArray: Array[Double] = if ($(fitIntercept)) { - Array.tabulate(numCoefficientSets) { i => - val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1 - rawCoefficients(coefIndex) - } - } else { - Array.empty[Double] + // center the intercepts when using multinomial algorithm + if ($(fitIntercept) && isMultinomial) { + val interceptArray = interceptVec.toArray + val interceptMean = interceptArray.sum / interceptArray.length + (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } } - val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) { - // The intercepts are never regularized, so we always center the mean. 
- val interceptMean = interceptsArray.sum / numClasses - interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean } - Vectors.dense(interceptsArray) - } else if (interceptsArray.length == 1) { - Vectors.dense(interceptsArray) - } else { - Vectors.sparse(numCoefficientSets, Seq()) - } - (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result()) + (compressedCoefficientMatrix, interceptVec.compressed, arrayBuilder.result()) } } @@ -651,7 +648,7 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + model.setSummary(Some(logRegSummary)) } else { model } @@ -697,6 +694,7 @@ class LogisticRegressionModel private[spark] ( /** * A vector of model coefficients for "binomial" logistic regression. If this model was trained * using the "multinomial" family then an exception is thrown. + * * @return Vector */ @Since("2.0.0") @@ -720,6 +718,7 @@ class LogisticRegressionModel private[spark] ( /** * The model intercept for "binomial" logistic regression. If this model was fit with the * "multinomial" family then an exception is thrown. + * * @return Double */ @Since("1.3.0") @@ -791,9 +790,9 @@ class LogisticRegressionModel private[spark] ( } } - private[classification] def setSummary( - summary: LogisticRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[classification] + def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -888,8 +887,7 @@ class LogisticRegressionModel private[spark] ( override def copy(extra: ParamMap): LogisticRegressionModel = { val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, isMultinomial), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { @@ -1179,8 +1177,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. * See http://en.wikipedia.org/wiki/Receiver_operating_characteristic * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR") @@ -1188,8 +1186,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Computes the area under the receiver operating characteristic (ROC) curve. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC() @@ -1198,8 +1196,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Returns the precision-recall curve, which is a Dataframe containing * two fields recall, precision with (0.0, 1.0) prepended to it. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. 
+ * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision") @@ -1207,8 +1205,8 @@ class BinaryLogisticRegressionSummary private[classification] ( /** * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val fMeasureByThreshold: DataFrame = { @@ -1220,8 +1218,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the precision. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val precisionByThreshold: DataFrame = { @@ -1233,8 +1231,8 @@ class BinaryLogisticRegressionSummary private[classification] ( * Every possible probability obtained in transforming the dataset are used * as thresholds used in calculating the recall. * - * Note: This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from `LogisticRegression.weightCol`. + * This will change in later Spark versions. */ @Since("1.5.0") @transient lazy val recallByThreshold: DataFrame = { @@ -1395,6 +1393,12 @@ class BinaryLogisticRegressionSummary private[classification] ( * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. * @param multinomial Whether to use multinomial (softmax) or binary loss + * + * @note In order to avoid unnecessary computation during calculation of the gradient updates + * we lay out the coefficients in column major order during training. This allows us to + * perform feature standardization once, while still retaining sequential memory access + * for speed. We convert back to row major order when we create the model, + * since this form is optimal for the matrix operations used for prediction. 
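As a dependency-free illustration of the note above, the sketch below shows how a (feature, class) pair maps to a flat index in the column-major layout used during training; the object and helper names are illustrative and not part of Spark's API.

```scala
// numCoefficientSets is numClasses for multinomial models and 1 for binomial ones.
object CoefficientLayoutSketch {
  def coefficientIndex(featureIndex: Int, classIndex: Int, numCoefficientSets: Int): Int =
    featureIndex * numCoefficientSets + classIndex

  def interceptIndex(classIndex: Int, numFeatures: Int, numCoefficientSets: Int): Int =
    numFeatures * numCoefficientSets + classIndex

  def main(args: Array[String]): Unit = {
    val (numFeatures, numClasses) = (3, 2)
    // For a fixed feature, the coefficients of all classes sit in adjacent slots,
    // so the aggregator can standardize a feature value once and reuse it per class.
    for (j <- 0 until numFeatures; c <- 0 until numClasses) {
      println(s"feature=$j class=$c -> flat=${coefficientIndex(j, c, numClasses)}")
    }
    for (c <- 0 until numClasses) {
      println(s"intercept class=$c -> flat=${interceptIndex(c, numFeatures, numClasses)}")
    }
  }
}
```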
*/ private class LogisticAggregator( bcCoefficients: Broadcast[Vector], @@ -1406,6 +1410,7 @@ private class LogisticAggregator( private val numFeatures = bcFeaturesStd.value.length private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures private val coefficientSize = bcCoefficients.value.size + private val numCoefficientSets = if (multinomial) numClasses else 1 if (multinomial) { require(numClasses == coefficientSize / numFeaturesPlusIntercept, s"The number of " + s"coefficients should be ${numClasses * numFeaturesPlusIntercept} but was $coefficientSize") @@ -1486,23 +1491,25 @@ private class LogisticAggregator( var marginOfLabel = 0.0 var maxMargin = Double.NegativeInfinity - val margins = Array.tabulate(numClasses) { i => - var margin = 0.0 - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - margin += localCoefficients(i * numFeaturesPlusIntercept + index) * - value / localFeaturesStd(index) - } + val margins = new Array[Double](numClasses) + features.foreachActive { (index, value) => + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + margins(j) += localCoefficients(index * numClasses + j) * stdValue + j += 1 } - + } + var i = 0 + while (i < numClasses) { if (fitIntercept) { - margin += localCoefficients(i * numFeaturesPlusIntercept + numFeatures) + margins(i) += localCoefficients(numClasses * numFeatures + i) } - if (i == label.toInt) marginOfLabel = margin - if (margin > maxMargin) { - maxMargin = margin + if (i == label.toInt) marginOfLabel = margins(i) + if (margins(i) > maxMargin) { + maxMargin = margins(i) } - margin + i += 1 } /** @@ -1510,33 +1517,39 @@ private class LogisticAggregator( * We address this by subtracting maxMargin from all the margins, so it's guaranteed * that all of the new margins will be smaller than zero to prevent arithmetic overflow. 
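The overflow guard described here is the usual max-subtraction trick for softmax. A stand-alone sketch of the idea (it mirrors the `maxMargin > 0` condition used by the aggregator, but is not Spark code):

```scala
object StableSoftmaxSketch {
  // Subtract the largest margin before exponentiating so every exponent is <= 0.
  def softmax(margins: Array[Double]): Array[Double] = {
    val maxMargin = margins.max
    val shifted = if (maxMargin > 0) margins.map(_ - maxMargin) else margins
    val exps = shifted.map(math.exp)
    val sum = exps.sum
    exps.map(_ / sum)
  }

  def main(args: Array[String]): Unit = {
    // Without the shift, math.exp(1000.0) overflows to Infinity and the result becomes NaN.
    println(softmax(Array(1000.0, 999.0, 990.0)).mkString(", "))
  }
}
```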
*/ + val multipliers = new Array[Double](numClasses) val sum = { var temp = 0.0 - if (maxMargin > 0) { - for (i <- 0 until numClasses) { - margins(i) -= maxMargin - temp += math.exp(margins(i)) - } - } else { - for (i <- 0 until numClasses) { - temp += math.exp(margins(i)) - } + var i = 0 + while (i < numClasses) { + if (maxMargin > 0) margins(i) -= maxMargin + val exp = math.exp(margins(i)) + temp += exp + multipliers(i) = exp + i += 1 } temp } - for (i <- 0 until numClasses) { - val multiplier = math.exp(margins(i)) / sum - { - if (label == i) 1.0 else 0.0 - } - features.foreachActive { (index, value) => - if (localFeaturesStd(index) != 0.0 && value != 0.0) { - localGradientArray(i * numFeaturesPlusIntercept + index) += - weight * multiplier * value / localFeaturesStd(index) + margins.indices.foreach { i => + multipliers(i) = multipliers(i) / sum - (if (label == i) 1.0 else 0.0) + } + features.foreachActive { (index, value) => + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + val stdValue = value / localFeaturesStd(index) + var j = 0 + while (j < numClasses) { + localGradientArray(index * numClasses + j) += + weight * multipliers(j) * stdValue + j += 1 } } - if (fitIntercept) { - localGradientArray(i * numFeaturesPlusIntercept + numFeatures) += weight * multiplier + } + if (fitIntercept) { + var i = 0 + while (i < numClasses) { + localGradientArray(numFeatures * numClasses + i) += weight * multipliers(i) + i += 1 } } @@ -1607,12 +1620,12 @@ private class LogisticAggregator( lossSum / weightSum } - def gradient: Vector = { + def gradient: Matrix = { require(weightSum > 0.0, s"The effective number of instances should be " + s"greater than 0.0, but $weightSum.") val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) - result + new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, result.toArray) } } @@ -1637,6 +1650,8 @@ private class LogisticCostFun( val bcCoeffs = instances.context.broadcast(coeffs) val featuresStd = bcFeaturesStd.value val numFeatures = featuresStd.length + val numCoefficientSets = if (multinomial) numClasses else 1 + val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures val logisticAggregator = { val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) @@ -1648,28 +1663,25 @@ private class LogisticCostFun( )(seqOp, combOp, aggregationDepth) } - val totalGradientArray = logisticAggregator.gradient.toArray + val totalGradientMatrix = logisticAggregator.gradient + val coefMatrix = new DenseMatrix(numCoefficientSets, numFeaturesPlusIntercept, coeffs.toArray) // regVal is the sum of coefficients squares excluding intercept for L2 regularization. val regVal = if (regParamL2 == 0.0) { 0.0 } else { var sum = 0.0 - coeffs.foreachActive { case (index, value) => + coefMatrix.foreachActive { case (classIndex, featureIndex, value) => // We do not apply regularization to the intercepts - val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0) + val isIntercept = fitIntercept && (featureIndex == numFeatures) if (!isIntercept) { // The following code will compute the loss of the regularization; also // the gradient of the regularization, and add back to totalGradientArray. 
sum += { if (standardization) { - totalGradientArray(index) += regParamL2 * value + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * value) value * value } else { - val featureIndex = if (fitIntercept) { - index % (numFeatures + 1) - } else { - index % numFeatures - } if (featuresStd(featureIndex) != 0.0) { // If `standardization` is false, we still standardize the data // to improve the rate of convergence; as a result, we have to @@ -1677,7 +1689,8 @@ private class LogisticCostFun( // differently to get effectively the same objective function when // the training dataset is not standardized. val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex)) - totalGradientArray(index) += regParamL2 * temp + val gradValue = totalGradientMatrix(classIndex, featureIndex) + totalGradientMatrix.update(classIndex, featureIndex, gradValue + regParamL2 * temp) value * temp } else { 0.0 @@ -1690,6 +1703,6 @@ private class LogisticCostFun( } bcCoeffs.destroy(blocking = false) - (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) + (logisticAggregator.loss + regVal, new BDV(totalGradientMatrix.toArray)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 994ed993c99df..f1a7676c74b0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -76,7 +76,7 @@ class NaiveBayes @Since("1.5.0") ( extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel] with NaiveBayesParams with DefaultParamsWritable { - import NaiveBayes.{Bernoulli, Multinomial} + import NaiveBayes._ @Since("1.5.0") def this() = this(Identifiable.randomUID("nb")) @@ -110,21 +110,20 @@ class NaiveBayes @Since("1.5.0") ( @Since("2.1.0") def setWeightCol(value: String): this.type = set(weightCol, value) + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { + trainWithLabelCheck(dataset, positiveLabel = true) + } + /** * ml assumes input labels in range [0, numClasses). But this implementation * is also called by mllib NaiveBayes which allows other kinds of input labels - * such as {-1, +1}. Here we use this parameter to switch between different processing logic. - * It should be removed when we remove mllib NaiveBayes. + * such as {-1, +1}. `positiveLabel` is used to determine whether the label + * should be checked and it should be removed when we remove mllib NaiveBayes. 
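For context on the value checks that `NaiveBayes` applies (the `requireNonnegativeValues` and `requireZeroOneBernoulliValues` helpers moved to the companion object further down), here is a small usage sketch; the toy data and column names are illustrative:

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object NaiveBayesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("nb-sketch").getOrCreate()
    import spark.implicits._

    // Multinomial NB requires nonnegative feature values (e.g. term counts);
    // the Bernoulli variant additionally requires every value to be 0 or 1.
    val train = Seq(
      (0.0, Vectors.dense(1.0, 0.0, 3.0)),
      (1.0, Vectors.dense(0.0, 2.0, 0.0))
    ).toDF("label", "features")

    val model = new NaiveBayes().setModelType("multinomial").fit(train)
    model.transform(train).select("label", "prediction").show()
    spark.stop()
  }
}
```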
*/ - private[spark] var isML: Boolean = true - - private[spark] def setIsML(isML: Boolean): this.type = { - this.isML = isML - this - } - - override protected def train(dataset: Dataset[_]): NaiveBayesModel = { - if (isML) { + private[spark] def trainWithLabelCheck( + dataset: Dataset[_], + positiveLabel: Boolean): NaiveBayesModel = { + if (positiveLabel) { val numClasses = getNumClasses(dataset) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -133,28 +132,9 @@ class NaiveBayes @Since("1.5.0") ( } } - val requireNonnegativeValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(_ >= 0.0), - s"Naive Bayes requires nonnegative feature values but found $v.") - } - - val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { - val values = v match { - case sv: SparseVector => sv.values - case dv: DenseVector => dv.values - } - - require(values.forall(v => v == 0.0 || v == 1.0), - s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") - } - + val modelTypeValue = $(modelType) val requireValues: Vector => Unit = { - $(modelType) match { + modelTypeValue match { case Multinomial => requireNonnegativeValues case Bernoulli => @@ -171,7 +151,7 @@ class NaiveBayes @Since("1.5.0") ( // Aggregates term frequencies per label. // TODO: Calling aggregateByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. - val aggregated = dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd + val aggregated = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd .map { row => (row.getDouble(0), (row.getDouble(1), row.getAs[Vector](2))) }.aggregateByKey[(Double, DenseVector)]((0.0, Vectors.zeros(numFeatures).toDense))( seqOp = { @@ -226,13 +206,33 @@ class NaiveBayes @Since("1.5.0") ( @Since("1.6.0") object NaiveBayes extends DefaultParamsReadable[NaiveBayes] { /** String name for multinomial model type. */ - private[spark] val Multinomial: String = "multinomial" + private[classification] val Multinomial: String = "multinomial" /** String name for Bernoulli model type. 
*/ - private[spark] val Bernoulli: String = "bernoulli" + private[classification] val Bernoulli: String = "bernoulli" /* Set of modelTypes that NaiveBayes supports */ - private[spark] val supportedModelTypes = Set(Multinomial, Bernoulli) + private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli) + + private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(_ >= 0.0), + s"Naive Bayes requires nonnegative feature values but found $v.") + } + + private[NaiveBayes] def requireZeroOneBernoulliValues(v: Vector): Unit = { + val values = v match { + case sv: SparseVector => sv.values + case dv: DenseVector => dv.values + } + + require(values.forall(v => v == 0.0 || v == 1.0), + s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.") + } @Since("1.6.0") override def load(path: String): NaiveBayes = super.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 2718dd93dcb5a..e6ca3aedffd9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -94,8 +94,8 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { - val copied = new BisectingKMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -131,8 +131,8 @@ class BisectingKMeansModel private[ml] ( private var trainingSummary: Option[BisectingKMeansSummary] = None - private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -264,7 +264,7 @@ class BisectingKMeans @Since("2.0.0") ( val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this)) val summary = new BisectingKMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 8fac63fefbb55..92d0b7d085f12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -89,8 +89,8 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { - val copied = new GaussianMixtureModel(uid, weights, gaussians) - copyValues(copied, extra).setParent(this.parent) + val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) + copied.setSummary(trainingSummary).setParent(this.parent) } @Since("2.0.0") @@ -149,8 +149,8 @@ class GaussianMixtureModel private[ml] ( private var trainingSummary: Option[GaussianMixtureSummary] = None - private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: 
Option[GaussianMixtureSummary]): this.type = { + this.trainingSummary = summary this } @@ -267,9 +267,9 @@ object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. */ @Since("2.0.0") @Experimental @@ -339,7 +339,7 @@ class GaussianMixture @Since("2.0.0") ( .setParent(this) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logNumFeatures(model.gaussians.head.mean.size) instr.logSuccess(model) model diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 85bb8c93b3fa9..152bd13b7a17a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Common params for KMeans and KMeansModel @@ -108,8 +109,8 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { - val copied = new KMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new KMeansModel(uid, parentModel), extra) + copied.setSummary(trainingSummary).setParent(this.parent) } /** @group setParam */ @@ -163,8 +164,8 @@ class KMeansModel private[ml] ( private var trainingSummary: Option[KMeansSummary] = None - private[clustering] def setSummary(summary: KMeansSummary): this.type = { - this.trainingSummary = Some(summary) + private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = { + this.trainingSummary = summary this } @@ -231,10 +232,7 @@ object KMeansModel extends MLReadable[KMeansModel] { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) val dataPath = new Path(path, "data").toString - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - - val clusterCenters = if (major.toInt >= 2) { + val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) { val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data] data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML) } else { @@ -326,7 +324,7 @@ class KMeans @Since("1.5.0") ( val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) val summary = new KMeansSummary( model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) - model.setSummary(summary) + model.setSummary(Some(summary)) instr.logSuccess(model) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index d0385e220e1e2..653fa41124f88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -42,69 +42,80 @@ private[feature] trait ChiSqSelectorParams extends Params with HasFeaturesCol with HasOutputCol with HasLabelCol { /** - * Number of features that selector will select (ordered by statistic value descending). If the + * Number of features that selector will select, ordered by ascending p-value. If the * number of features is less than numTopFeatures, then this will select all features. - * Only applicable when selectorType = "kbest". + * Only applicable when selectorType = "numTopFeatures". * The default value of numTopFeatures is 50. * * @group param */ + @Since("1.6.0") final val numTopFeatures = new IntParam(this, "numTopFeatures", - "Number of features that selector will select, ordered by statistics value descending. If the" + + "Number of features that selector will select, ordered by ascending p-value. If the" + " number of features is < numTopFeatures, then this will select all features.", ParamValidators.gtEq(1)) setDefault(numTopFeatures -> 50) /** @group getParam */ + @Since("1.6.0") def getNumTopFeatures: Int = $(numTopFeatures) /** * Percentile of features that selector will select, ordered by statistics value descending. * Only applicable when selectorType = "percentile". * Default value is 0.1. + * @group param */ + @Since("2.1.0") final val percentile = new DoubleParam(this, "percentile", - "Percentile of features that selector will select, ordered by statistics value descending.", + "Percentile of features that selector will select, ordered by ascending p-value.", ParamValidators.inRange(0, 1)) setDefault(percentile -> 0.1) /** @group getParam */ + @Since("2.1.0") def getPercentile: Double = $(percentile) /** * The highest p-value for features to be kept. * Only applicable when selectorType = "fpr". * Default value is 0.05. + * @group param */ - final val alpha = new DoubleParam(this, "alpha", "The highest p-value for features to be kept.", + final val fpr = new DoubleParam(this, "fpr", "The highest p-value for features to be kept.", ParamValidators.inRange(0, 1)) - setDefault(alpha -> 0.05) + setDefault(fpr -> 0.05) /** @group getParam */ - def getAlpha: Double = $(alpha) + def getFpr: Double = $(fpr) /** * The selector type of the ChisqSelector. - * Supported options: "kbest" (default), "percentile" and "fpr". + * Supported options: "numTopFeatures" (default), "percentile", "fpr". + * @group param */ + @Since("2.1.0") final val selectorType = new Param[String](this, "selectorType", "The selector type of the ChisqSelector. " + - "Supported options: kbest (default), percentile and fpr.", - ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes.toArray)) - setDefault(selectorType -> OldChiSqSelector.KBest) + "Supported options: " + OldChiSqSelector.supportedSelectorTypes.mkString(", "), + ParamValidators.inArray[String](OldChiSqSelector.supportedSelectorTypes)) + setDefault(selectorType -> OldChiSqSelector.NumTopFeatures) /** @group getParam */ + @Since("2.1.0") def getSelectorType: String = $(selectorType) } /** * Chi-Squared feature selection, which selects categorical features to use for predicting a * categorical label. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. 
- * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. + * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.6.0") final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: String) @@ -113,10 +124,6 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str @Since("1.6.0") def this() = this(Identifiable.randomUID("chiSqSelector")) - /** @group setParam */ - @Since("2.1.0") - def setSelectorType(value: String): this.type = set(selectorType, value) - /** @group setParam */ @Since("1.6.0") def setNumTopFeatures(value: Int): this.type = set(numTopFeatures, value) @@ -127,7 +134,11 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str /** @group setParam */ @Since("2.1.0") - def setAlpha(value: Double): this.type = set(alpha, value) + def setFpr(value: Double): this.type = set(fpr, value) + + /** @group setParam */ + @Since("2.1.0") + def setSelectorType(value: String): this.type = set(selectorType, value) /** @group setParam */ @Since("1.6.0") @@ -153,15 +164,15 @@ final class ChiSqSelector @Since("1.6.0") (@Since("1.6.0") override val uid: Str .setSelectorType($(selectorType)) .setNumTopFeatures($(numTopFeatures)) .setPercentile($(percentile)) - .setAlpha($(alpha)) + .setFpr($(fpr)) val model = selector.fit(input) copyValues(new ChiSqSelectorModel(uid, model).setParent(this)) } @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - val otherPairs = OldChiSqSelector.supportedTypeAndParamPairs.filter(_._1 != $(selectorType)) - otherPairs.foreach { case (_, paramName: String) => + val otherPairs = OldChiSqSelector.supportedSelectorTypes.filter(_ != $(selectorType)) + otherPairs.foreach { paramName: String => if (isSet(getParam(paramName))) { logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index 6386dd8a10801..46a0730f5ddb8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -44,7 +44,8 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * @group param */ final val minDocFreq = new IntParam( - this, "minDocFreq", "minimum number of documents in which a term should appear for filtering") + this, "minDocFreq", "minimum number of documents in which a term should appear for filtering" + + " (>= 0)", ParamValidators.gtEq(0)) setDefault(minDocFreq -> 0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala index 333a8c364a884..eb117c40eea3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala @@ -40,7 +40,7 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol { * @group param */ final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" + - "increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + + " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" + " improves the running performance", ParamValidators.gt(0)) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 28cbe1cb01e9a..ccfb0ce8f85ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -85,7 +85,8 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H *

* * For the case $E_{max} == E_{min}$, $Rescaled(e_i) = 0.5 * (max + min)$. - * Note that since zero values will probably be transformed to non-zero values, output of the + * + * @note Since zero values will probably be transformed to non-zero values, output of the * transformer will be DenseVector even for sparse input. */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index e8e28ba29c841..ea401216aec7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -36,7 +36,8 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The last category is not included by default (configurable via [[OneHotEncoder!.dropLast]] * because it makes the vector entries sum up to one, and hence linearly dependent. * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`. - * Note that this is different from scikit-learn's OneHotEncoder, which keeps all categories. + * + * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories. * The output vectors are sparse. * * @see [[StringIndexer]] for converting categorical values into category indices diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 6b913480fdc28..6e08bf059124c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.VersionUtils.majorVersion /** * Params for [[PCA]] and [[PCAModel]]. @@ -44,7 +45,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC * The number of principal components. * @group param */ - final val k: IntParam = new IntParam(this, "k", "the number of principal components") + final val k: IntParam = new IntParam(this, "k", "the number of principal components (> 0)", + ParamValidators.gt(0)) /** @group getParam */ def getK: Int = $(k) @@ -140,8 +142,9 @@ class PCAModel private[ml] ( /** * Transform a vector by computed Principal Components. - * NOTE: Vectors to be transformed must be the same length - * as the source vectors given to [[PCA.fit()]]. + * + * @note Vectors to be transformed must be the same length as the source vectors given + * to [[PCA.fit()]]. 
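A short usage sketch of the constraints touched in this hunk: `k` must now be strictly positive, and vectors passed to `transform` must have the same dimensionality as those given to `PCA.fit()`. The data below is illustrative only.

```scala
import org.apache.spark.ml.feature.PCA
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object PcaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("pca-sketch").getOrCreate()
    import spark.implicits._

    // Three-dimensional inputs: setK(0) is now rejected by the param validator,
    // and transform() expects three-dimensional vectors as well.
    val df = Seq(
      Vectors.dense(1.0, 0.0, 7.0),
      Vectors.dense(2.0, 1.0, 9.0),
      Vectors.dense(4.0, 3.0, 5.0)
    ).map(Tuple1.apply).toDF("features")

    val model = new PCA().setInputCol("features").setOutputCol("pca").setK(2).fit(df)
    model.transform(df).show(truncate = false)
    spark.stop()
  }
}
```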
*/ @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -203,11 +206,8 @@ object PCAModel extends MLReadable[PCAModel] { override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val versionRegex = "([0-9]+)\\.(.+)".r - val versionRegex(major, _) = metadata.sparkVersion - val dataPath = new Path(path, "data").toString - val model = if (major.toInt >= 2) { + val model = if (majorVersion(metadata.sparkVersion) >= 2) { val Row(pc: DenseMatrix, explainedVariance: DenseVector) = sparkSession.read.parquet(dataPath) .select("pc", "explainedVariance") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 666070037cdd8..0ced21365ff6f 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -28,7 +28,10 @@ import org.apache.spark.sql.types.{ArrayType, StringType, StructType} /** * A feature transformer that filters out stop words from input. - * Note: null values from input array are preserved unless adding null to stopWords explicitly. + * + * @note null values from input array are preserved unless adding null to stopWords + * explicitly. + * * @see [[http://en.wikipedia.org/wiki/Stop_words]] */ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 80fe46796f807..8b155f00017cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -113,11 +113,11 @@ object StringIndexer extends DefaultParamsReadable[StringIndexer] { /** * Model fitted by [[StringIndexer]]. * - * NOTE: During transformation, if the input column does not exist, + * @param labels Ordered list of labels, corresponding to indices to be assigned. + * + * @note During transformation, if the input column does not exist, * [[StringIndexerModel.transform]] would return the input dataset unmodified. * This is a temporary fix for the case when target labels do not exist during prediction. - * - * @param labels Ordered list of labels, corresponding to indices to be assigned. 
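The model readers in this patch (`KMeansModel`, `PCAModel`) now share `VersionUtils.majorVersion` instead of repeating an ad-hoc version regex. `VersionUtils` is internal to Spark, so the sketch below only mirrors the parsing it is expected to perform and is not its actual source:

```scala
object MajorVersionSketch {
  private val versionRegex = """^(\d+)\.(\d+)(\..*)?$""".r

  def majorMinorVersion(sparkVersion: String): (Int, Int) = sparkVersion match {
    case versionRegex(major, minor, _) => (major.toInt, minor.toInt)
    case _ => throw new IllegalArgumentException(s"Could not parse Spark version: $sparkVersion")
  }

  def majorVersion(sparkVersion: String): Int = majorMinorVersion(sparkVersion)._1

  def main(args: Array[String]): Unit = {
    println(majorVersion("2.1.0"))          // 2 -> read the new Parquet layout
    println(majorVersion("1.6.3"))          // 1 -> fall back to the pre-2.0 loader path
    println(majorVersion("2.1.0-SNAPSHOT")) // 2 -> suffixes after major.minor are ignored
  }
}
```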
*/ @Since("1.4.0") class StringIndexerModel ( diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index d53f3df514dff..3ed08c983d561 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -43,7 +43,8 @@ private[feature] trait Word2VecBase extends Params * @group param */ final val vectorSize = new IntParam( - this, "vectorSize", "the dimension of codes after transforming from words") + this, "vectorSize", "the dimension of codes after transforming from words (> 0)", + ParamValidators.gt(0)) setDefault(vectorSize -> 100) /** @group getParam */ @@ -55,7 +56,8 @@ private[feature] trait Word2VecBase extends Params * @group expertParam */ final val windowSize = new IntParam( - this, "windowSize", "the window size (context words from [-window, window])") + this, "windowSize", "the window size (context words from [-window, window]) (> 0)", + ParamValidators.gt(0)) setDefault(windowSize -> 5) /** @group expertGetParam */ @@ -67,7 +69,8 @@ private[feature] trait Word2VecBase extends Params * @group param */ final val numPartitions = new IntParam( - this, "numPartitions", "number of partitions for sentences of words") + this, "numPartitions", "number of partitions for sentences of words (> 0)", + ParamValidators.gt(0)) setDefault(numPartitions -> 1) /** @group getParam */ @@ -80,7 +83,7 @@ private[feature] trait Word2VecBase extends Params * @group param */ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " + - "appear to be included in the word2vec model's vocabulary") + "appear to be included in the word2vec model's vocabulary (>= 0)", ParamValidators.gtEq(0)) setDefault(minCount -> 5) /** @group getParam */ @@ -95,7 +98,7 @@ private[feature] trait Word2VecBase extends Params */ final val maxSentenceLength = new IntParam(this, "maxSentenceLength", "Maximum length " + "(in words) of each sentence in the input data. 
Any sentence longer than this threshold will " + - "be divided into chunks up to the size.") + "be divided into chunks up to the size (> 0)", ParamValidators.gt(0)) setDefault(maxSentenceLength -> 1000) /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala index 2f5299b010223..96fd0d18b5ae9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala @@ -16,9 +16,10 @@ */ package org.apache.spark.ml.optim +import scala.collection.mutable + import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} -import scala.collection.mutable import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors} import org.apache.spark.mllib.linalg.CholeskyDecomposition @@ -57,7 +58,7 @@ private[ml] sealed trait NormalEquationSolver { */ private[ml] class CholeskySolver extends NormalEquationSolver { - def solve( + override def solve( bBar: Double, bbBar: Double, abBar: DenseVector, @@ -80,7 +81,7 @@ private[ml] class QuasiNewtonSolver( tol: Double, l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver { - def solve( + override def solve( bBar: Double, bbBar: Double, abBar: DenseVector, @@ -156,7 +157,7 @@ private[ml] class QuasiNewtonSolver( * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible * (singular). */ -class SingularMatrixException(message: String, cause: Throwable) +private[spark] class SingularMatrixException(message: String, cause: Throwable) extends IllegalArgumentException(message, cause) { def this(message: String) = this(message, null) diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala index 90c24e1b590ea..56ab9675700a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala @@ -47,7 +47,7 @@ private[ml] class WeightedLeastSquaresModel( * formulation: * * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,, - * + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^ + * + lambda / delta (1/2 (1 - alpha) sum,,j,, (sigma,,j,, x,,j,,)^2^ * + alpha sum,,j,, abs(sigma,,j,, x,,j,,)), * * where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter, @@ -91,7 +91,7 @@ private[ml] class WeightedLeastSquares( require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0, s"elasticNetParam must be in [0, 1]: $elasticNetParam") require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter") - require(tol > 0, s"tol must be greater than zero: $tol") + require(tol >= 0.0, s"tol must be >= 0, but was set to $tol") /** * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 9245931b27ca6..96206e0b7ad88 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -533,7 +533,7 @@ trait Params extends Identifiable with Serializable { * Returns all params sorted by their names. 
The default implementation uses Java reflection to * list all public methods that have no arguments and return [[Param]]. * - * Note: Developer should not use this method in constructor because we cannot guarantee that + * @note Developer should not use this method in constructor because we cannot guarantee that * this variable gets initialized before other params. */ lazy val params: Array[Param[_]] = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala new file mode 100644 index 0000000000000..aacb41ee2659b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTClassificationWrapper.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTClassifierWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + import GBTClassifierWrapper._ + + private val gbtcModel: GBTClassificationModel = + pipeline.stages(1).asInstanceOf[GBTClassificationModel] + + lazy val numFeatures: Int = gbtcModel.numFeatures + lazy val featureImportances: Vector = gbtcModel.featureImportances + lazy val numTrees: Int = gbtcModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtcModel.treeWeights + + def summary: String = gbtcModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(gbtcModel.getFeaturesCol) + .drop(gbtcModel.getLabelCol) + } + + override def write: MLWriter = new + GBTClassifierWrapper.GBTClassifierWrapperWriter(this) +} + +private[r] object GBTClassifierWrapper extends MLReadable[GBTClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTClassifierWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + 
.setForceIndexLabel(true) + checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) + + // assemble and fit the pipeline + val rfc = new GBTClassifier() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfc, idxToStr)) + .fit(data) + + new GBTClassifierWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTClassifierWrapper] = new GBTClassifierWrapperReader + + override def load(path: String): GBTClassifierWrapper = super.load(path) + + class GBTClassifierWrapperWriter(instance: GBTClassifierWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class GBTClassifierWrapperReader extends MLReader[GBTClassifierWrapper] { + + override def load(path: String): GBTClassifierWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTClassifierWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala new file mode 100644 index 0000000000000..585077588eb9b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GBTRegressionWrapper.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.hadoop.fs.Path +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} + +private[r] class GBTRegressorWrapper private ( + val pipeline: PipelineModel, + val formula: String, + val features: Array[String]) extends MLWritable { + + private val gbtrModel: GBTRegressionModel = + pipeline.stages(1).asInstanceOf[GBTRegressionModel] + + lazy val numFeatures: Int = gbtrModel.numFeatures + lazy val featureImportances: Vector = gbtrModel.featureImportances + lazy val numTrees: Int = gbtrModel.getNumTrees + lazy val treeWeights: Array[Double] = gbtrModel.treeWeights + + def summary: String = gbtrModel.toDebugString + + def transform(dataset: Dataset[_]): DataFrame = { + pipeline.transform(dataset).drop(gbtrModel.getFeaturesCol) + } + + override def write: MLWriter = new + GBTRegressorWrapper.GBTRegressorWrapperWriter(this) +} + +private[r] object GBTRegressorWrapper extends MLReadable[GBTRegressorWrapper] { + def fit( // scalastyle:ignore + data: DataFrame, + formula: String, + maxDepth: Int, + maxBins: Int, + maxIter: Int, + stepSize: Double, + minInstancesPerNode: Int, + minInfoGain: Double, + checkpointInterval: Int, + lossType: String, + seed: String, + subsamplingRate: Double, + maxMemoryInMB: Int, + cacheNodeIds: Boolean): GBTRegressorWrapper = { + + val rFormula = new RFormula() + .setFormula(formula) + RWrapperUtils.checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) + + // get feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + + // assemble and fit the pipeline + val rfr = new GBTRegressor() + .setMaxDepth(maxDepth) + .setMaxBins(maxBins) + .setMaxIter(maxIter) + .setStepSize(stepSize) + .setMinInstancesPerNode(minInstancesPerNode) + .setMinInfoGain(minInfoGain) + .setCheckpointInterval(checkpointInterval) + .setLossType(lossType) + .setSubsamplingRate(subsamplingRate) + .setMaxMemoryInMB(maxMemoryInMB) + .setCacheNodeIds(cacheNodeIds) + .setFeaturesCol(rFormula.getFeaturesCol) + if (seed != null && seed.length > 0) rfr.setSeed(seed.toLong) + + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, rfr)) + .fit(data) + + new GBTRegressorWrapper(pipeline, formula, features) + } + + override def read: MLReader[GBTRegressorWrapper] = new GBTRegressorWrapperReader + + override def load(path: String): GBTRegressorWrapper = super.load(path) + + class GBTRegressorWrapperWriter(instance: GBTRegressorWrapper) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + + val rMetadata = ("class" -> instance.getClass.getName) ~ + ("formula" -> instance.formula) ~ + ("features" -> instance.features.toSeq) + val rMetadataJson: String = compact(render(rMetadata)) + + sc.parallelize(Seq(rMetadataJson), 
1).saveAsTextFile(rMetadataPath) + instance.pipeline.save(pipelinePath) + } + } + + class GBTRegressorWrapperReader extends MLReader[GBTRegressorWrapper] { + + override def load(path: String): GBTRegressorWrapper = { + implicit val format = DefaultFormats + val rMetadataPath = new Path(path, "rMetadata").toString + val pipelinePath = new Path(path, "pipeline").toString + val pipeline = PipelineModel.load(pipelinePath) + + val rMetadataStr = sc.textFile(rMetadataPath, 1).first() + val rMetadata = parse(rMetadataStr) + val formula = (rMetadata \ "formula").extract[String] + val features = (rMetadata \ "features").extract[Array[String]] + + new GBTRegressorWrapper(pipeline, formula, features) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index b1bb577e1ffe4..add4d49110d16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -23,11 +23,17 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.regression._ +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ private[r] class GeneralizedLinearRegressionWrapper private ( val pipeline: PipelineModel, @@ -42,6 +48,8 @@ private[r] class GeneralizedLinearRegressionWrapper private ( val rNumIterations: Int, val isLoaded: Boolean = false) extends MLWritable { + import GeneralizedLinearRegressionWrapper._ + private val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -52,7 +60,16 @@ private[r] class GeneralizedLinearRegressionWrapper private ( def residuals(residualsType: String): DataFrame = glm.summary.residuals(residualsType) def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(glm.getFeaturesCol) + if (rFamily == "binomial") { + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_PROB_COL) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(glm.getFeaturesCol) + .drop(glm.getLabelCol) + } else { + pipeline.transform(dataset) + .drop(glm.getFeaturesCol) + } } override def write: MLWriter = @@ -62,6 +79,10 @@ private[r] class GeneralizedLinearRegressionWrapper private ( private[r] object GeneralizedLinearRegressionWrapper extends MLReadable[GeneralizedLinearRegressionWrapper] { + val PREDICTED_LABEL_PROB_COL = "pred_label_prob" + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( formula: String, data: DataFrame, @@ -71,9 +92,9 @@ private[r] object GeneralizedLinearRegressionWrapper maxIter: Int, weightCol: String, regParam: Double): GeneralizedLinearRegressionWrapper = { - val rFormula = new RFormula() - .setFormula(formula) - RWrapperUtils.checkDataColumns(rFormula, data) + val rFormula = new RFormula().setFormula(formula) + if (family == "binomial") 
rFormula.setForceIndexLabel(true) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema val schema = rFormulaModel.transform(data).schema @@ -90,9 +111,28 @@ private[r] object GeneralizedLinearRegressionWrapper .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) - val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, glr)) - .fit(data) + .setLabelCol(rFormula.getLabelCol) + val pipeline = if (family == "binomial") { + // Convert prediction from probability to label index. + val probToPred = new ProbabilityToPrediction() + .setInputCol(PREDICTED_LABEL_PROB_COL) + .setOutputCol(PREDICTED_LABEL_INDEX_COL) + // Convert prediction from label index to original label. + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + + new Pipeline() + .setStages(Array(rFormulaModel, glr.setPredictionCol(PREDICTED_LABEL_PROB_COL), + probToPred, idxToStr)) + .fit(data) + } else { + new Pipeline().setStages(Array(rFormulaModel, glr)).fit(data) + } val glm: GeneralizedLinearRegressionModel = pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] @@ -200,3 +240,27 @@ private[r] object GeneralizedLinearRegressionWrapper } } } + +/** + * This utility transformer converts the predicted value of GeneralizedLinearRegressionModel + * with "binomial" family from probability to prediction according to threshold 0.5. + */ +private[r] class ProbabilityToPrediction private[r] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("probToPred")) + + def setInputCol(value: String): this.type = set(inputCol, value) + + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields :+ StructField($(outputCol), DoubleType)) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + dataset.withColumn($(outputCol), round(col($(inputCol)))) + } + + override def copy(extra: ParamMap): ProbabilityToPrediction = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index 2193eb80e9fdd..d34de30931143 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -24,19 +24,29 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier} +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter} import org.apache.spark.sql.{DataFrame, Dataset} private[r] class MultilayerPerceptronClassifierWrapper private ( - val pipeline: PipelineModel, - val labelCount: Long, - val layers: Array[Int], - val weights: Array[Double] + val pipeline: PipelineModel ) extends MLWritable { + 
import MultilayerPerceptronClassifierWrapper._ + + val mlpModel: MultilayerPerceptronClassificationModel = + pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel] + + val weights: Array[Double] = mlpModel.weights.toArray + val layers: Array[Int] = mlpModel.layers + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) + .drop(mlpModel.getFeaturesCol) + .drop(mlpModel.getLabelCol) + .drop(PREDICTED_LABEL_INDEX_COL) } /** @@ -49,10 +59,12 @@ private[r] class MultilayerPerceptronClassifierWrapper private ( private[r] object MultilayerPerceptronClassifierWrapper extends MLReadable[MultilayerPerceptronClassifierWrapper] { + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" val PREDICTED_LABEL_COL = "prediction" def fit( data: DataFrame, + formula: String, blockSize: Int, layers: Array[Int], solver: String, @@ -62,8 +74,13 @@ private[r] object MultilayerPerceptronClassifierWrapper seed: String, initialWeights: Array[Double] ): MultilayerPerceptronClassifierWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + .setForceIndexLabel(true) + checkDataColumns(rFormula, data) + val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = data.schema + val (_, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val mlp = new MultilayerPerceptronClassifier() @@ -73,25 +90,25 @@ private[r] object MultilayerPerceptronClassifierWrapper .setMaxIter(maxIter) .setTol(tol) .setStepSize(stepSize) - .setPredictionCol(PREDICTED_LABEL_COL) + .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt) if (initialWeights != null) { require(initialWeights.length > 0) mlp.setInitialWeights(Vectors.dense(initialWeights)) } + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() - .setStages(Array(mlp)) + .setStages(Array(rFormulaModel, mlp, idxToStr)) .fit(data) - val multilayerPerceptronClassificationModel: MultilayerPerceptronClassificationModel = - pipeline.stages.head.asInstanceOf[MultilayerPerceptronClassificationModel] - - val weights = multilayerPerceptronClassificationModel.weights.toArray - val layersFromPipeline = multilayerPerceptronClassificationModel.layers - val labelCount = data.select("label").distinct().count() - - new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layersFromPipeline, weights) + new MultilayerPerceptronClassifierWrapper(pipeline) } /** @@ -107,17 +124,10 @@ private[r] object MultilayerPerceptronClassifierWrapper override def load(path: String): MultilayerPerceptronClassifierWrapper = { implicit val format = DefaultFormats - val rMetadataPath = new Path(path, "rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString - val rMetadataStr = sc.textFile(rMetadataPath, 1).first() - val rMetadata = parse(rMetadataStr) - val labelCount = (rMetadata \ "labelCount").extract[Long] - val layers = (rMetadata \ "layers").extract[Array[Int]] - val weights = (rMetadata \ "weights").extract[Array[Double]] - val pipeline = PipelineModel.load(pipelinePath) - new MultilayerPerceptronClassifierWrapper(pipeline, labelCount, layers, weights) + new MultilayerPerceptronClassifierWrapper(pipeline) } } @@ -128,10 +138,7 @@ private[r] object MultilayerPerceptronClassifierWrapper val rMetadataPath = new Path(path, 
"rMetadata").toString val pipelinePath = new Path(path, "pipeline").toString - val rMetadata = ("class" -> instance.getClass.getName) ~ - ("labelCount" -> instance.labelCount) ~ - ("layers" -> instance.layers.toSeq) ~ - ("weights" -> instance.weights.toArray.toSeq) + val rMetadata = "class" -> instance.getClass.getName val rMetadataJson: String = compact(render(rMetadata)) sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 4fdab2dd94655..0afea4be3d1dd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -23,9 +23,9 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -46,6 +46,7 @@ private[r] class NaiveBayesWrapper private ( pipeline.transform(dataset) .drop(PREDICTED_LABEL_INDEX_COL) .drop(naiveBayesModel.getFeaturesCol) + .drop(naiveBayesModel.getLabelCol) } override def write: MLWriter = new NaiveBayesWrapper.NaiveBayesWrapperWriter(this) @@ -60,21 +61,16 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] { val rFormula = new RFormula() .setFormula(formula) .setForceIndexLabel(true) - RWrapperUtils.checkDataColumns(rFormula, data) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) // get labels and feature names from output schema - val schema = rFormulaModel.transform(data).schema - val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) - .asInstanceOf[NominalAttribute] - val labels = labelAttr.values.get - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val naiveBayes = new NaiveBayes() .setSmoothing(smoothing) .setModelType("bernoulli") .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala index 379007c4d948d..665e50af67d46 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala @@ -18,11 +18,12 @@ package org.apache.spark.ml.r import org.apache.spark.internal.Logging -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} +import org.apache.spark.ml.feature.{RFormula, RFormulaModel} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.Dataset -object RWrapperUtils extends Logging { +private[r] object RWrapperUtils extends Logging { /** * DataFrame column check. 
@@ -32,14 +33,41 @@ object RWrapperUtils extends Logging { * * @param rFormula RFormula instance * @param data Input dataset - * @return Unit */ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = { if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) { val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}" - logWarning(s"data containing ${rFormula.getFeaturesCol} column, " + + logInfo(s"data containing ${rFormula.getFeaturesCol} column, " + s"using new name $newFeaturesName instead") rFormula.setFeaturesCol(newFeaturesName) } + + if (rFormula.getForceIndexLabel && data.schema.fieldNames.contains(rFormula.getLabelCol)) { + val newLabelName = s"${Identifiable.randomUID(rFormula.getLabelCol)}" + logInfo(s"data containing ${rFormula.getLabelCol} column and we force to index label, " + + s"using new name $newLabelName instead") + rFormula.setLabelCol(newLabelName) + } + } + + /** + * Get the feature names and original labels from the schema + * of DataFrame transformed by RFormulaModel. + * + * @param rFormulaModel The RFormulaModel instance. + * @param data Input dataset. + * @return The feature names and original labels. + */ + def getFeaturesAndLabels( + rFormulaModel: RFormulaModel, + data: Dataset[_]): (Array[String], Array[String]) = { + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol)) + .asInstanceOf[NominalAttribute] + val labels = labelAttr.values.get + (features, labels) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala index 0e09e18027ca7..b59fe292349bf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala @@ -60,6 +60,10 @@ private[r] object RWrappers extends MLReader[Object] { RandomForestRegressorWrapper.load(path) case "org.apache.spark.ml.r.RandomForestClassifierWrapper" => RandomForestClassifierWrapper.load(path) + case "org.apache.spark.ml.r.GBTRegressorWrapper" => + GBTRegressorWrapper.load(path) + case "org.apache.spark.ml.r.GBTClassifierWrapper" => + GBTClassifierWrapper.load(path) case _ => throw new SparkException(s"SparkR read.ml does not support load $className") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala index b0088ddaf3b1d..0b860e5af96e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestClassificationWrapper.scala @@ -23,10 +23,10 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.feature.{IndexToString, RFormula} import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.r.RWrapperUtils._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -35,18 +35,23 @@ private[r] class RandomForestClassifierWrapper private ( val formula: String, val features: 
Array[String]) extends MLWritable { - private val DTModel: RandomForestClassificationModel = + import RandomForestClassifierWrapper._ + + private val rfcModel: RandomForestClassificationModel = pipeline.stages(1).asInstanceOf[RandomForestClassificationModel] - lazy val numFeatures: Int = DTModel.numFeatures - lazy val featureImportances: Vector = DTModel.featureImportances - lazy val numTrees: Int = DTModel.getNumTrees - lazy val treeWeights: Array[Double] = DTModel.treeWeights + lazy val numFeatures: Int = rfcModel.numFeatures + lazy val featureImportances: Vector = rfcModel.featureImportances + lazy val numTrees: Int = rfcModel.getNumTrees + lazy val treeWeights: Array[Double] = rfcModel.treeWeights - def summary: String = DTModel.toDebugString + def summary: String = rfcModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(rfcModel.getFeaturesCol) + .drop(rfcModel.getLabelCol) } override def write: MLWriter = new @@ -54,6 +59,10 @@ private[r] class RandomForestClassifierWrapper private ( } private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestClassifierWrapper] { + + val PREDICTED_LABEL_INDEX_COL = "pred_label_idx" + val PREDICTED_LABEL_COL = "prediction" + def fit( // scalastyle:ignore data: DataFrame, formula: String, @@ -73,14 +82,12 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC val rFormula = new RFormula() .setFormula(formula) - RWrapperUtils.checkDataColumns(rFormula, data) + .setForceIndexLabel(true) + checkDataColumns(rFormula, data) val rFormulaModel = rFormula.fit(data) - // get feature names from output schema - val schema = rFormulaModel.transform(data).schema - val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol)) - .attributes.get - val features = featureAttrs.map(_.name.get) + // get labels and feature names from output schema + val (features, labels) = getFeaturesAndLabels(rFormulaModel, data) // assemble and fit the pipeline val rfc = new RandomForestClassifier() @@ -97,10 +104,17 @@ private[r] object RandomForestClassifierWrapper extends MLReadable[RandomForestC .setCacheNodeIds(cacheNodeIds) .setProbabilityCol(probabilityCol) .setFeaturesCol(rFormula.getFeaturesCol) + .setLabelCol(rFormula.getLabelCol) + .setPredictionCol(PREDICTED_LABEL_INDEX_COL) if (seed != null && seed.length > 0) rfc.setSeed(seed.toLong) + val idxToStr = new IndexToString() + .setInputCol(PREDICTED_LABEL_INDEX_COL) + .setOutputCol(PREDICTED_LABEL_COL) + .setLabels(labels) + val pipeline = new Pipeline() - .setStages(Array(rFormulaModel, rfc)) + .setStages(Array(rFormulaModel, rfc, idxToStr)) .fit(data) new RandomForestClassifierWrapper(pipeline, formula, features) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala index c8874407fa75e..4b9a3a731da9b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/RandomForestRegressionWrapper.scala @@ -35,18 +35,18 @@ private[r] class RandomForestRegressorWrapper private ( val formula: String, val features: Array[String]) extends MLWritable { - private val DTModel: RandomForestRegressionModel = + private val rfrModel: RandomForestRegressionModel = pipeline.stages(1).asInstanceOf[RandomForestRegressionModel] - 
lazy val numFeatures: Int = DTModel.numFeatures - lazy val featureImportances: Vector = DTModel.featureImportances - lazy val numTrees: Int = DTModel.getNumTrees - lazy val treeWeights: Array[Double] = DTModel.treeWeights + lazy val numFeatures: Int = rfrModel.numFeatures + lazy val featureImportances: Vector = rfrModel.featureImportances + lazy val numTrees: Int = rfrModel.getNumTrees + lazy val treeWeights: Array[Double] = rfrModel.treeWeights - def summary: String = DTModel.toDebugString + def summary: String = rfrModel.toDebugString def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(DTModel.getFeaturesCol) + pipeline.transform(dataset).drop(rfrModel.getFeaturesCol) } override def write: MLWriter = new diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 02e2384afe530..6d2c59a905ec7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -678,6 +678,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { checkpointInterval: Int = 10, seed: Long = 0L)( implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { + require(!ratings.isEmpty(), s"No ratings available from $ratings") require(intermediateRDDStorageLevel != StorageLevel.NONE, "ALS is not designed to run without persisting intermediate RDDs.") val sc = ratings.sparkContext diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index ebc6c12ddcf92..1419da874709f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -207,9 +207,9 @@ class DecisionTreeRegressionModel private[ml] ( * where gain is scaled by the number of instances passing through node * - Normalize importances for tree to sum to 1. * - * Note: Feature importance for single decision trees can have high variance due to - * correlated predictor variables. Consider using a [[RandomForestRegressor]] - * to determine feature importance instead. + * @note Feature importance for single decision trees can have high variance due to + * correlated predictor variables. Consider using a [[RandomForestRegressor]] + * to determine feature importance instead. 
*/ @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 33cb25c8c7f66..3f9de1fe74c9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -255,7 +255,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -270,7 +270,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, wlsModel.diagInvAtWA.toArray, 1, getSolver) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } // Fit Generalized Linear Model by iteratively reweighted least squares (IRLS). @@ -284,7 +284,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val .setParent(this)) val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("2.0.0") @@ -501,8 +501,8 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine val defaultLink: Link = Log override def initialize(y: Double, weight: Double): Double = { - require(y > 0.0, "The response variable of Poisson family " + - s"should be positive, but got $y") + require(y >= 0.0, "The response variable of Poisson family " + + s"should be non-negative, but got $y") y } @@ -761,8 +761,8 @@ class GeneralizedLinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.nonEmpty private[regression] - def setSummary(summary: GeneralizedLinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -776,8 +776,9 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { - copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - .setParent(parent) + val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), + extra) + copied.setSummary(trainingSummary).setParent(parent) } /** @@ -877,8 +878,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( * Private copy of model to ensure Params are not modified outside this class. * Coefficients is not a deep copy, but that is acceptable. * - * NOTE: [[predictionCol]] must be set correctly before the value of [[model]] is set, - * and [[model]] must be set before [[predictions]] is set! + * @note [[predictionCol]] must be set correctly before the value of [[model]] is set, + * and [[model]] must be set before [[predictions]] is set! 
*/ protected val model: GeneralizedLinearRegressionModel = origModel.copy(ParamMap.empty).setPredictionCol(predictionCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index cd7b4f2a9c56e..4d274f3a5bbf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -61,7 +61,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * @group param */ final val featureIndex: IntParam = new IntParam(this, "featureIndex", - "The index of the feature if featuresCol is a vector column, no effect otherwise.") + "The index of the feature if featuresCol is a vector column, no effect otherwise (>= 0)", + ParamValidators.gtEq(0)) /** @group getParam */ final def getFeatureIndex: Int = $(featureIndex) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 519f3bdec82df..8ea5e1e6c453a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ -import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares} +import org.apache.spark.ml.optim.WeightedLeastSquares import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ @@ -103,11 +103,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Whether to standardize the training features before fitting the model. * The coefficients of models will be always returned on the original scale, - * so it will be transparent for users. Note that with/without standardization, - * the models should be always converged to the same solution when no regularization - * is applied. In R's GLMNET package, the default behavior is true as well. + * so it will be transparent for users. * Default is true. * + * @note With/without standardization, the models should be always converged + * to the same solution when no regularization is applied. In R's GLMNET package, + * the default behavior is true as well. + * * @group setParam */ @Since("1.5.0") @@ -160,16 +162,22 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String /** * Set the solver algorithm used for optimization. * In case of linear regression, this can be "l-bfgs", "normal" and "auto". - * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton - * optimization method. "normal" denotes using Normal Equation as an analytical - * solution to the linear regression problem. - * The default value is "auto" which means that the solver algorithm is - * selected automatically. + * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. + * - "normal" denotes using Normal Equation as an analytical solution to the linear regression + * problem. This solver is limited to [[LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER]]. + * - "auto" (default) means that the solver algorithm is selected automatically. 
+ * The Normal Equations solver will be used when possible, but this will automatically fall + * back to iterative optimization methods when needed. * * @group setParam */ @Since("1.6.0") - def setSolver(value: String): this.type = set(solver, value) + def setSolver(value: String): this.type = { + require(Set("auto", "l-bfgs", "normal").contains(value), + s"Solver $value was not supported. Supported options: auto, l-bfgs, normal") + set(solver, value) + } setDefault(solver -> "auto") /** @@ -190,7 +198,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = dataset.select( - col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + col($(labelCol)), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -217,7 +225,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model.diagInvAtWA.toArray, model.objectiveHistory) - return lrModel.setSummary(trainingSummary) + return lrModel.setSummary(Some(trainingSummary)) } val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -270,7 +278,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), Array(0D)) - return model.setSummary(trainingSummary) + return model.setSummary(Some(trainingSummary)) } else { require($(regParam) == 0.0, "The standard deviation of the label is zero. " + "Model cannot be regularized.") @@ -392,7 +400,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String model, Array(0D), objectiveHistory) - model.setSummary(trainingSummary) + model.setSummary(Some(trainingSummary)) } @Since("1.4.0") @@ -404,6 +412,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] { @Since("1.6.0") override def load(path: String): LinearRegression = super.load(path) + + /** + * When using [[LinearRegression.solver]] == "normal", the solver must limit the number of + * features to at most this number. The entire covariance matrix X^T^X will be collected + * to the driver. This limit helps prevent memory overflow errors. + */ + @Since("2.1.0") + val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES } /** @@ -430,8 +446,9 @@ class LinearRegressionModel private[ml] ( throw new SparkException("No training summary available for this LinearRegressionModel") } - private[regression] def setSummary(summary: LinearRegressionTrainingSummary): this.type = { - this.trainingSummary = Some(summary) + private[regression] + def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = { + this.trainingSummary = summary this } @@ -474,8 +491,7 @@ class LinearRegressionModel private[ml] ( @Since("1.4.0") override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, coefficients, intercept), extra) - if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel.setParent(parent) + newModel.setSummary(trainingSummary).setParent(parent) } /** @@ -610,8 +626,8 @@ class LinearRegressionSummary private[regression] ( * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. 
- * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val explainedVariance: Double = metrics.explainedVariance @@ -620,8 +636,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean absolute error, which is a risk function corresponding to the * expected value of the absolute error loss or l1-norm loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanAbsoluteError: Double = metrics.meanAbsoluteError @@ -630,8 +646,8 @@ class LinearRegressionSummary private[regression] ( * Returns the mean squared error, which is a risk function corresponding to the * expected value of the squared error loss or quadratic loss. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val meanSquaredError: Double = metrics.meanSquaredError @@ -640,8 +656,8 @@ class LinearRegressionSummary private[regression] ( * Returns the root mean squared error, which is defined as the square root of * the mean squared error. * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val rootMeanSquaredError: Double = metrics.rootMeanSquaredError @@ -650,8 +666,8 @@ class LinearRegressionSummary private[regression] ( * Returns R^2^, the coefficient of determination. * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] * - * Note: This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. - * This will change in later Spark versions. + * @note This ignores instance weights (setting all to 1.0) from [[LinearRegression.weightCol]]. + * This will change in later Spark versions. */ @Since("1.5.0") val r2: Double = metrics.r2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala index 73d813064decb..e1376927030e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMDataSource.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, DataFrameReader} * inconsistent feature dimensions. * - "vectorType": feature vector type, "sparse" (default) or "dense". * - * Note that this class is public for documentation purpose. Please don't use this class directly. + * @note This class is public for documentation purpose. Please don't use this class directly. * Rather, use the data source API as illustrated above. 
* * @see [[https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ LIBSVM datasets]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 7bef899a633d9..0a0bc4c006389 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -34,7 +34,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param seed Random seed. * @return tuple of ensemble models and weights: * (array of decision tree models, array of model weights) @@ -59,7 +59,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to validate a gradient boosting model - * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param input Training dataset: RDD of [[LabeledPoint]]. * @param validationInput Validation dataset. * This dataset should be different from the training dataset, * but it should follow the same distribution. @@ -98,7 +98,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ def computeInitialPredictionAndError( @@ -121,7 +121,7 @@ private[spark] object GradientBoostedTrees extends Logging { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. */ def updatePredictionError( @@ -162,7 +162,7 @@ private[spark] object GradientBoostedTrees extends Logging { * Method to calculate error of the base learner for the gradient boosting calculation. * Note: This method is not used by the gradient boosting algorithm but is useful for debugging * purposes. - * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param data Training dataset: RDD of [[LabeledPoint]]. * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. @@ -184,7 +184,7 @@ private[spark] object GradientBoostedTrees extends Logging { /** * Method to compute error or loss for every iteration of gradient boosting. * - * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param data RDD of [[LabeledPoint]] * @param trees Boosted Decision Tree models * @param treeWeights Learning rates at each boosting iteration. * @param loss evaluation metric. 
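For context on the `LibSVMDataSource` note above ("use the data source API as illustrated above"): a reader of this diff does not see that illustration, so a rough sketch is given below. The object name, file path, and option values are illustrative only; the `"numFeatures"` and `"vectorType"` options are the ones listed in the class documentation.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical example, not part of this patch.
object LibSVMReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("libsvm-read-example").getOrCreate()

    // Read LIBSVM-formatted data through the data source API rather than
    // referencing LibSVMDataSource directly.
    val df = spark.read.format("libsvm")
      .option("numFeatures", "780")    // avoids an extra pass to infer the feature dimension
      .option("vectorType", "sparse")  // "sparse" (default) or "dense"
      .load("data/mllib/sample_libsvm_data.txt")

    df.printSchema()  // columns: label (double), features (vector)
    df.show(5)

    spark.stop()
  }
}
```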
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index b504f411d256d..8ae5ca3c84b0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -82,7 +82,7 @@ private[spark] object RandomForest extends Logging { /** * Train a random forest. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @return an unweighted set of trees */ def run( @@ -343,7 +343,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] + * @param input Training data: RDD of [[TreePoint]] * @param metadata Learning and dataset metadata * @param topNodesForGroup For each tree in group, tree index -> root node. * Used for matching instances with nodes. @@ -854,10 +854,10 @@ private[spark] object RandomForest extends Logging { * and for multiclass classification with a high-arity feature, * there is one bin per category. * - * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param input Training data: RDD of [[LabeledPoint]] * @param metadata Learning and dataset metadata * @param seed random seed - * @return Splits, an Array of [[org.apache.spark.mllib.tree.model.Split]] + * @return Splits, an Array of [[Split]] * of size (numFeatures, numSplits) */ protected[tree] def findSplits( diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 57c7e44e97607..5a551533be9ca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -73,11 +73,13 @@ private[ml] trait DecisionTreeParams extends PredictorParams /** * Minimum information gain for a split to be considered at a tree node. + * Should be >= 0.0. * (default = 0.0) * @group param */ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain", - "Minimum information gain for a split to be considered at a tree node.") + "Minimum information gain for a split to be considered at a tree node.", + ParamValidators.gtEq(0.0)) /** * Maximum memory in MB allocated to histogram aggregation. 
If too small, then 1 node will be diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb8814a..5d1a39f7c16db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], validationMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index bc4f9e6716ee8..e5fa5d53e3fca 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -221,7 +221,7 @@ trait MLReadable[T] { /** * Reads an ML instance from the input path, a shortcut of `read.load(path)`. * - * Note: Implementing classes should override this to be Java-friendly. + * @note Implementing classes should override this to be Java-friendly. */ @Since("1.6.0") def load(path: String): T = read.load(path) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 904000f50d0a2..034e3625e8c01 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -638,13 +638,13 @@ private[python] class PythonMLLibAPI extends Serializable { selectorType: String, numTopFeatures: Int, percentile: Double, - alpha: Double, + fpr: Double, data: JavaRDD[LabeledPoint]): ChiSqSelectorModel = { new ChiSqSelector() .setSelectorType(selectorType) .setNumTopFeatures(numTopFeatures) .setPercentile(percentile) - .setAlpha(alpha) + .setFpr(fpr) .fit(data.rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index d851b983349c9..4b650000736e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -202,9 +202,11 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] { * Train a classification model for Binary Logistic Regression * using Stochastic Gradient Descent. By default L2 regularization is used, * which can be changed via `LogisticRegressionWithSGD.optimizer`. - * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. + * * Using [[LogisticRegressionWithLBFGS]] is recommended over this. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("0.8.0") class LogisticRegressionWithSGD private[mllib] ( @@ -239,7 +241,8 @@ class LogisticRegressionWithSGD private[mllib] ( /** * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent. 
- * NOTE: Labels used in Logistic Regression should be {0, 1} + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("0.8.0") @deprecated("Use ml.classification.LogisticRegression or LogisticRegressionWithLBFGS", "2.0.0") @@ -252,7 +255,6 @@ object LogisticRegressionWithSGD { * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -260,6 +262,8 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -276,13 +280,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. - * @param miniBatchFraction Fraction of data to be used per iteration. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -298,13 +302,13 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using the specified step size. We use the entire data * set to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. - * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -318,11 +322,12 @@ object LogisticRegressionWithSGD { * Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed * number of iterations of gradient descent using a step size of 1.0. We use the entire data set * to update the gradient in each iteration. - * NOTE: Labels used in Logistic Regression should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * + * @note Labels used in Logistic Regression should be {0, 1} */ @Since("1.0.0") def train( @@ -335,8 +340,6 @@ object LogisticRegressionWithSGD { /** * Train a classification model for Multinomial/Binary Logistic Regression using * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default. 
- * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1} - * for k classes multi-label classification problem. * * Earlier implementations of LogisticRegressionWithLBFGS applies a regularization * penalty to all elements including the intercept. If this is called with one of @@ -344,6 +347,9 @@ object LogisticRegressionWithSGD { * into a call to ml.LogisticRegression, otherwise this will use the existing mllib * GeneralizedLinearAlgorithm trainer, resulting in a regularization penalty to the * intercept. + * + * @note Labels used in Logistic Regression should be {0, 1, ..., k - 1} + * for k classes multi-label classification problem. */ @Since("1.1.0") class LogisticRegressionWithLBFGS diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 33561be4b5bc1..767d056861a8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -364,12 +364,12 @@ class NaiveBayes private ( val nb = new NewNaiveBayes() .setModelType(modelType) .setSmoothing(lambda) - .setIsML(false) val dataset = data.map { case LabeledPoint(label, features) => (label, features.asML) } .toDF("label", "features") - val newModel = nb.fit(dataset) + // mllib NaiveBayes allows input labels like {-1, +1}, so set `positiveLabel` as false. + val newModel = nb.trainWithLabelCheck(dataset, positiveLabel = false) val pi = newModel.pi.toArray val theta = Array.fill[Double](newModel.numClasses, newModel.numFeatures)(0.0) @@ -378,7 +378,7 @@ class NaiveBayes private ( theta(i)(j) = v } - require(newModel.oldLabels != null, + assert(newModel.oldLabels != null, "The underlying ML NaiveBayes training does not produce labels.") new NaiveBayesModel(newModel.oldLabels, pi, theta, modelType) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 7c3ccbb40b812..aec1526b55c49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -125,7 +125,8 @@ object SVMModel extends Loader[SVMModel] { /** * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. - * NOTE: Labels used in SVM should be {0, 1}. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") class SVMWithSGD private ( @@ -158,7 +159,9 @@ class SVMWithSGD private ( } /** - * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}. + * Top-level methods for calling SVM. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") object SVMWithSGD { @@ -169,8 +172,6 @@ object SVMWithSGD { * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. * - * NOTE: Labels used in SVM should be {0, 1}. - * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @param stepSize Step size to be used for each iteration of gradient descent. @@ -178,6 +179,8 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. 
Array should be equal in size to * the number of features in the data. + * + * @note Labels used in SVM should be {0, 1}. */ @Since("0.8.0") def train( @@ -195,7 +198,8 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. - * NOTE: Labels used in SVM should be {0, 1} + * + * @note Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -217,13 +221,14 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train( @@ -238,11 +243,12 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using a step size of 1.0. We use the entire data set to * update the gradient in each iteration. - * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * + * @note Labels used in SVM should be {0, 1} */ @Since("0.8.0") def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 43193adf3e184..56cdeea5f7a3f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -41,14 +41,14 @@ import org.apache.spark.util.Utils * While this process is generally guaranteed to converge, it is not guaranteed * to find a global optimum. * - * Note: For high-dimensional data (with many features), this algorithm may perform poorly. - * This is due to high-dimensional data (a) making it difficult to cluster at all (based - * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. - * * @param k Number of independent Gaussians in the mixture model. * @param convergenceTol Maximum change in log-likelihood at which convergence * is considered to have occurred. * @param maxIterations Maximum number of iterations allowed. + * + * @note For high-dimensional data (with many features), this algorithm may perform poorly. + * This is due to high-dimensional data (a) making it difficult to cluster at all (based + * on statistical/theoretical arguments) and (b) numerical issues with Gaussian distributions. 
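For context on the GaussianMixture `@note` being moved above: the algorithm is intended for low-dimensional data. A minimal usage sketch (illustrative only, assuming an active `SparkContext` `sc` and toy 2-D points):

```scala
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.linalg.Vectors

// Two well-separated, low-dimensional clusters.
val points = sc.parallelize(Seq(
  Vectors.dense(0.1, 0.2), Vectors.dense(0.2, 0.1),
  Vectors.dense(9.0, 9.1), Vectors.dense(9.2, 8.9)))

val gmm = new GaussianMixture()
  .setK(2)                  // number of Gaussians in the mixture
  .setConvergenceTol(0.01)  // max change in log-likelihood at convergence
  .setMaxIterations(100)
  .run(points)

gmm.gaussians.zip(gmm.weights).foreach { case (g, w) =>
  println(s"weight=$w mu=${g.mu} sigma=${g.sigma}")
}
```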
*/ @Since("1.3.0") class GaussianMixture private ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index ed9c064879d01..fa72b72e2d921 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -56,14 +56,18 @@ class KMeans private ( def this() = this(2, 20, KMeans.K_MEANS_PARALLEL, 2, 1e-4, Utils.random.nextLong()) /** - * Number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. */ @Since("1.4.0") def getK: Int = k /** - * Set the number of clusters to create (k). Note that it is possible for fewer than k clusters to + * Set the number of clusters to create (k). + * + * @note It is possible for fewer than k clusters to * be returned, for example, if there are fewer than k distinct points to cluster. Default: 2. */ @Since("0.8.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index d999b9be8e8ac..7c52abdeaac22 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -175,7 +175,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. */ @Since("1.3.0") @@ -187,7 +187,7 @@ class LDA private ( * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. * * If set to -1, then topicConcentration is set automatically. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 90d8a558f10d4..b5b0e64a2a6c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -66,7 +66,7 @@ abstract class LDAModel private[clustering] extends Saveable { * * This is the parameter to a symmetric Dirichlet distribution. * - * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * @note The topics' distributions over terms are called "beta" in the original LDA paper * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. 
*/ @Since("1.5.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index ae324f86fe6d1..7365ea1f200da 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -93,9 +93,11 @@ final class EMLDAOptimizer extends LDAOptimizer { /** * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with - * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * care. * * Default: true + * + * @note Checkpoints will be cleaned up via reference counting, regardless. */ @Since("2.0.0") def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { @@ -348,7 +350,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { * Mini-batch fraction in (0, 1], which sets the fraction of document sampled and used in * each iteration. * - * Note that this should be adjusted in synch with [[LDA.setMaxIterations()]] + * @note This should be adjusted in synch with [[LDA.setMaxIterations()]] * so the entire corpus is used. Specifically, set both so that * maxIterations * miniBatchFraction >= 1. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index f0779491e6374..003d1411a9cf7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -39,7 +39,7 @@ private[evaluation] object AreaUnderCurve { /** * Returns the area under the given curve. * - * @param curve a RDD of ordered 2D points stored in pairs representing a curve + * @param curve an RDD of ordered 2D points stored in pairs representing a curve */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index f8276de4f23d4..f9156b642785f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -161,7 +161,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { Loader.checkSchema[Data](dataFrame.schema) val features = dataArray.rdd.map { - case Row(feature: Int) => (feature) + case Row(feature: Int) => feature }.collect() new ChiSqSelectorModel(features) @@ -171,18 +171,20 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { /** * Creates a ChiSquared feature selector. - * The selector supports three selection methods: `kbest`, `percentile` and `fpr`. - * `kbest` chooses the `k` top features according to a chi-squared test. - * `percentile` is similar but chooses a fraction of all features instead of a fixed number. - * `fpr` chooses all features whose false positive rate meets some threshold. - * By default, the selection method is `kbest`, the default number of top features is 50. + * The selector supports different selection methods: `numTopFeatures`, `percentile`, `fpr`. + * - `numTopFeatures` chooses a fixed number of top features according to a chi-squared test. 
+ * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. + * - `fpr` chooses all features whose p-value is below a threshold, thus controlling the false + * positive rate of selection. + * By default, the selection method is `numTopFeatures`, with the default number of top features + * set to 50. */ @Since("1.3.0") class ChiSqSelector @Since("2.1.0") () extends Serializable { var numTopFeatures: Int = 50 var percentile: Double = 0.1 - var alpha: Double = 0.05 - var selectorType = ChiSqSelector.KBest + var fpr: Double = 0.05 + var selectorType = ChiSqSelector.NumTopFeatures /** * The is the same to call this() and setNumTopFeatures(numTopFeatures) @@ -207,15 +209,15 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } @Since("2.1.0") - def setAlpha(value: Double): this.type = { - require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") - alpha = value + def setFpr(value: Double): this.type = { + require(0.0 <= value && value <= 1.0, "FPR must be in [0,1]") + fpr = value this } @Since("2.1.0") def setSelectorType(value: String): this.type = { - require(ChiSqSelector.supportedSelectorTypes.toSeq.contains(value), + require(ChiSqSelector.supportedSelectorTypes.contains(value), s"ChiSqSelector Type: $value was not supported.") selectorType = value this @@ -232,7 +234,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { - case ChiSqSelector.KBest => + case ChiSqSelector.NumTopFeatures => chiSqTestResult .sortBy { case (res, _) => res.pValue } .take(numTopFeatures) @@ -242,7 +244,7 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => chiSqTestResult - .filter { case (res, _) => res.pValue < alpha } + .filter { case (res, _) => res.pValue < fpr } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } @@ -251,22 +253,17 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { } } -@Since("2.1.0") -object ChiSqSelector { +private[spark] object ChiSqSelector { - /** String name for `kbest` selector type. */ - private[spark] val KBest: String = "kbest" + /** String name for `numTopFeatures` selector type. */ + val NumTopFeatures: String = "numTopFeatures" /** String name for `percentile` selector type. */ - private[spark] val Percentile: String = "percentile" + val Percentile: String = "percentile" /** String name for `fpr` selector type. */ private[spark] val FPR: String = "fpr" - /** Set of selector type and param pairs that ChiSqSelector supports. */ - private[spark] val supportedTypeAndParamPairs = Set(KBest -> "numTopFeatures", - Percentile -> "percentile", FPR -> "alpha") - /** Set of selector types that ChiSqSelector supports. 
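Pulling the ChiSqSelector changes above together (the `kbest` to `numTopFeatures` rename and `setAlpha` to `setFpr`), a minimal sketch of the updated spark.mllib API, assuming an active `SparkContext` `sc` and toy labeled data:

```scala
import org.apache.spark.mllib.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(8.0, 7.0, 0.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 6.0)),
  LabeledPoint(1.0, Vectors.dense(0.0, 9.0, 8.0)),
  LabeledPoint(2.0, Vectors.dense(8.0, 9.0, 5.0))))

// "numTopFeatures" replaces the old "kbest" name; setFpr replaces setAlpha.
val selector = new ChiSqSelector()
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(1)

val transformer = selector.fit(data)
val filtered = data.map(lp => LabeledPoint(lp.label, transformer.transform(lp.features)))
```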
*/ - private[spark] val supportedSelectorTypes = supportedTypeAndParamPairs.map(_._1) + val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, FPR) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index fbd217af74ecb..c94d7890cf557 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * - * Note: Users should not implement this interface. + * @note Users should not implement this interface. */ @SQLUserDefinedType(udt = classOf[VectorUDT]) @Since("1.0.0") @@ -132,7 +132,9 @@ sealed trait Vector extends Serializable { /** * Number of active entries. An "active entry" is an element which is explicitly stored, - * regardless of its value. Note that inactive entries have value 0. + * regardless of its value. + * + * @note Inactive entries have value 0. */ @Since("1.4.0") def numActives: Int diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala index 377be6bfb9886..03866753b50ee 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -451,7 +451,7 @@ class BlockMatrix @Since("1.3.0") ( * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause * some performance issues until support for multiplying two sparse matrices is added. * - * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when + * @note The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added * with each other. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index b03b3ecde94f4..809906a158337 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -188,8 +188,9 @@ class IndexedRowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be - * computed on matrices with more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeGramianMatrix(): Matrix = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index ec32e37afb792..4b120332ab8d8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -106,8 +106,9 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the Gramian matrix `A^T A`. Note that this cannot be computed on matrices with - * more than 65535 columns. + * Computes the Gramian matrix `A^T A`. + * + * @note This cannot be computed on matrices with more than 65535 columns. 
*/ @Since("1.0.0") def computeGramianMatrix(): Matrix = { @@ -168,9 +169,6 @@ class RowMatrix @Since("1.0.0") ( * ARPACK is set to 300 or k * 3, whichever is larger. The numerical tolerance for ARPACK's * eigen-decomposition is set to 1e-10. * - * @note The conditions that decide which method to use internally and the default parameters are - * subject to change. - * * @param k number of leading singular values to keep (0 < k <= n). * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values @@ -180,6 +178,9 @@ class RowMatrix @Since("1.0.0") ( * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) * are treated as zero, where sigma(0) is the largest singular value. * @return SingularValueDecomposition(U, s, V). U = null if computeU = false. + * + * @note The conditions that decide which method to use internally and the default parameters are + * subject to change. */ @Since("1.0.0") def computeSVD( @@ -319,9 +320,11 @@ class RowMatrix @Since("1.0.0") ( } /** - * Computes the covariance matrix, treating each row as an observation. Note that this cannot - * be computed on matrices with more than 65535 columns. + * Computes the covariance matrix, treating each row as an observation. + * * @return a local dense matrix of size n x n + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.0.0") def computeCovariance(): Matrix = { @@ -369,12 +372,12 @@ class RowMatrix @Since("1.0.0") ( * The row data do not need to be "centered" first; it is not necessary for * the mean of each column to be 0. * - * Note that this cannot be computed on matrices with more than 65535 columns. - * * @param k number of top principal components. * @return a matrix of size n-by-k, whose columns are principal components, and * a vector of values which indicate how much variance each principal component * explains + * + * @note This cannot be computed on matrices with more than 65535 columns. */ @Since("1.6.0") def computePrincipalComponentsAndExplainedVariance(k: Int): (Matrix, Vector) = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 81e64de4e5b5d..c49e72646bf13 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -305,7 +305,8 @@ class LeastSquaresGradient extends Gradient { * :: DeveloperApi :: * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification. * See also the documentation for the precise formulation. 
- * NOTE: This assumes that the labels are {0,1} + * + * @note This assumes that the labels are {0,1} */ @DeveloperApi class HingeGradient extends Gradient { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index 426bb818c9266..f5ca1c221d66b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.beans.BeanProperty @@ -34,7 +34,7 @@ private[mllib] trait PMMLModelExport { val version = getClass.getPackage.getImplementationVersion val app = new Application("Apache Spark MLlib").setVersion(version) val timestamp = new Timestamp() - .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss").format(new Date())) + .addContent(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.US).format(new Date())) val header = new Header() .setApplication(app) .setTimestamp(timestamp) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 0f7857b8d8627..005119616f063 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -31,7 +31,7 @@ import org.apache.spark.rdd.RDD class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** - * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Returns an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index cc9ee15738ad6..0039db7ecbbc7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -236,6 +236,8 @@ class ALS private ( */ @Since("0.8.0") def run(ratings: RDD[Rating]): MatrixFactorizationModel = { + require(!ratings.isEmpty(), s"No ratings available from $ratings") + val sc = ratings.context val numUserBlocks = if (this.numUserBlocks == -1) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index c642573ccba6d..24e4dcccc843f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -43,14 +43,14 @@ import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. * - * Note: If you create the model directly using constructor, please be aware that fast prediction - * requires cached user/product features and their associated partitioners. - * * @param rank Rank for the features in this model. 
* @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. + * + * @note If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. */ @Since("0.8.0") class MatrixFactorizationModel @Since("0.8.0") ( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index f3159f7e724cc..925fdf4d7e7bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -60,15 +60,15 @@ object Statistics { * Compute the correlation matrix for the input RDD of Vectors using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note that for Spearman, a rank correlation, we need to create an RDD[Double] for each column - * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], - * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to - * avoid recomputing the common lineage. - * * @param X an RDD[Vector] for which the correlation matrix is to be computed. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return Correlation matrix comparing columns in X. + * + * @note For Spearman, a rank correlation, we need to create an RDD[Double] for each column + * and sort it in order to retrieve the ranks and then join the columns back into an RDD[Vector], + * which is fairly costly. Cache the input RDD before calling corr with `method = "spearman"` to + * avoid recomputing the common lineage. */ @Since("1.1.0") def corr(X: RDD[Vector], method: String): Matrix = Correlations.corrMatrix(X, method) @@ -77,12 +77,12 @@ object Statistics { * Compute the Pearson correlation for the input RDDs. * Returns NaN if either vector has 0 variance. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @return A Double containing the Pearson correlation between the two input RDD[Double]s + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. */ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double]): Double = Correlations.corr(x, y) @@ -98,15 +98,15 @@ object Statistics { * Compute the correlation for the input RDDs using the specified method. * Methods currently supported: `pearson` (default), `spearman`. * - * Note: the two input RDDs need to have the same number of partitions and the same number of - * elements in each partition. - * * @param x RDD[Double] of the same cardinality as y. * @param y RDD[Double] of the same cardinality as x. * @param method String specifying the method to use for computing correlation. * Supported: `pearson` (default), `spearman` * @return A Double containing the correlation between the two input RDD[Double]s using the * specified method. + * + * @note The two input RDDs need to have the same number of partitions and the same number of + * elements in each partition. 
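A small sketch (not from the patch) of the `Statistics.corr` variants whose notes are being converted above; it assumes an active `SparkContext` `sc`. The two `RDD[Double]` inputs are built identically so they share the same partitioning, and the vector input is cached before the Spearman computation, as the note advises:

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics

// Same length and same way of building both RDDs, so they have the same number
// of partitions and the same number of elements per partition.
val x = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0, 5.0))
val y = sc.parallelize(Seq(2.0, 3.9, 6.1, 8.0, 10.2))

val pearson = Statistics.corr(x, y)               // default method
val spearman = Statistics.corr(x, y, "spearman")  // rank correlation; costlier

// For the matrix form, cache the input before a Spearman computation.
val vecs = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0),
  Vectors.dense(3.0, 1.0),
  Vectors.dense(5.0, 4.0))).cache()
val corrMatrix = Statistics.corr(vecs, "spearman")
```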
*/ @Since("1.1.0") def corr(x: RDD[Double], y: RDD[Double], method: String): Double = Correlations.corr(x, y, method) @@ -122,15 +122,15 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the * expected distribution. * - * Note: the two input Vectors need to have the same size. - * `observed` cannot contain negative values. - * `expected` cannot contain nonpositive values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @param expected Vector containing the expected categorical counts/relative frequencies. * `expected` is rescaled if the `expected` sum differs from the `observed` sum. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note The two input Vectors need to have the same size. + * `observed` cannot contain negative values. + * `expected` cannot contain nonpositive values. */ @Since("1.1.0") def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { @@ -141,11 +141,11 @@ object Statistics { * Conduct Pearson's chi-squared goodness of fit test of the observed data against the uniform * distribution, with each category having an expected frequency of `1 / observed.size`. * - * Note: `observed` cannot contain negative values. - * * @param observed Vector containing the observed categorical counts/relative frequencies. * @return ChiSquaredTest object containing the test statistic, degrees of freedom, p-value, * the method used, and the null hypothesis. + * + * @note `observed` cannot contain negative values. */ @Since("1.1.0") def chiSqTest(observed: Vector): ChiSqTestResult = ChiSqTest.chiSquared(observed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 36feab7859b43..d846c43cf2913 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -75,10 +75,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -86,6 +82,10 @@ object DecisionTree extends Serializable with Logging { * of decision tree (classification or regression), feature type (continuous, * categorical), depth of the tree, quantile calculation strategy, etc. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.0.0") def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { @@ -96,10 +96,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. 
* - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -108,6 +104,10 @@ object DecisionTree extends Serializable with Logging { * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means * 1 internal node + 2 leaf nodes). * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.0.0") def train( @@ -123,10 +123,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -136,6 +132,10 @@ object DecisionTree extends Serializable with Logging { * 1 internal node + 2 leaf nodes). * @param numClasses Number of classes for classification. Default value of 2. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. */ @Since("1.2.0") def train( @@ -152,10 +152,6 @@ object DecisionTree extends Serializable with Logging { * Method to train a decision tree model. * The method supports binary and multiclass classification and regression. * - * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] - * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] - * is recommended to clearly separate classification and regression. - * * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * For classification, labels should take values {0, 1, ..., numClasses-1}. * For regression, labels are real numbers. @@ -170,6 +166,10 @@ object DecisionTree extends Serializable with Logging { * indicates that feature n is categorical with k categories * indexed from 0: {0, 1, ..., k-1}. * @return DecisionTreeModel that can be used for prediction. + * + * @note Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]] + * and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]] + * is recommended to clearly separate classification and regression. 
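Since the notes above recommend `trainClassifier`/`trainRegressor` over the generic `train`, a minimal classification sketch (illustrative only, assuming an active `SparkContext` `sc` and toy data):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree

val data = sc.parallelize(Seq(
  LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
  LabeledPoint(0.0, Vectors.dense(0.2, 0.9)),
  LabeledPoint(1.0, Vectors.dense(1.0, 0.1)),
  LabeledPoint(1.0, Vectors.dense(0.9, 0.0))))

val model = DecisionTree.trainClassifier(
  data,
  numClasses = 2,
  categoricalFeaturesInfo = Map[Int, Int](),  // all features treated as continuous
  impurity = "gini",
  maxDepth = 4,
  maxBins = 32)
```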
*/ @Since("1.0.0") def train( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index de14ddf024d75..09274a2e1b2ac 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -42,11 +42,13 @@ trait Loss extends Serializable { /** * Method to calculate error of the base learner for the gradient boosting calculation. - * Note: This method is not used by the gradient boosting algorithm but is useful for debugging - * purposes. + * * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. * @return Measure of model error on data + * + * @note This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. */ @Since("1.2.0") def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { @@ -55,11 +57,13 @@ trait Loss extends Serializable { /** * Method to calculate loss when the predictions are already known. - * Note: This method is used in the method evaluateEachIteration to avoid recomputing the - * predicted values from previously fit trees. + * * @param prediction Predicted label. * @param label True label. * @return Measure of model error on datapoint. + * + * @note This method is used in the method evaluateEachIteration to avoid recomputing the + * predicted values from previously fit trees. */ private[spark] def computeError(prediction: Double, label: Double): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala index 657ed0a8ecda8..299950785e420 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -187,7 +187,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param initTreeWeight: learning rate assigned to the first tree. * @param initTree: first DecisionTreeModel. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to every sample. */ @Since("1.4.0") @@ -213,7 +213,7 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] { * @param treeWeight: Learning rate. * @param tree: Tree using which the prediction and error should be updated. * @param loss: evaluation metric. - * @return a RDD with each element being a zip of the prediction and error + * @return an RDD with each element being a zip of the prediction and error * corresponding to each sample. 
*/ @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 6413ca1f8b19e..dafc6c200f95f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } + test("Pipeline.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF)) + val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10)) + + assert(copied.uid === pipeline.uid, + "copy should create an instance with the same UID") + assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("PipelineModel.copy") { val hashingTF = new HashingTF() .setNumFeatures(100) - val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF)) + .setParent(new Pipeline()) val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) - require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + + assert(copied.uid === model.uid, + "copy should create an instance with the same UID") + assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, "copy should handle extra stage params") + assert(copied.parent === model.parent, + "copy should create an instance with the same parent") } test("pipeline model constructors") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala new file mode 100644 index 0000000000000..03e0c536a973e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.linalg._ +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext { + + import PredictorSuite._ + + test("should support all NumericType labels and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF("label", "features") + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + + val predictor = new MockPredictor() + + types.foreach { t => + predictor.fit(df.select(col("label").cast(t), col("features"))) + } + + intercept[IllegalArgumentException] { + predictor.fit(df.select(col("label").cast(StringType), col("features"))) + } + } +} + +object PredictorSuite { + + class MockPredictor(override val uid: String) + extends Predictor[Vector, MockPredictor, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictor")) + + override def train(dataset: Dataset[_]): MockPredictionModel = { + require(dataset.schema("label").dataType == DoubleType) + new MockPredictionModel(uid) + } + + override def copy(extra: ParamMap): MockPredictor = + throw new NotImplementedError() + } + + class MockPredictionModel(override val uid: String) + extends PredictionModel[Vector, MockPredictionModel] { + + def this() = this(Identifiable.randomUID("mockpredictormodel")) + + override def predict(features: Vector): Double = + throw new NotImplementedError() + + override def copy(extra: ParamMap): MockPredictionModel = + throw new NotImplementedError() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index bc631dc6d3149..e360542eae2ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -141,6 +141,14 @@ class LogisticRegressionSuite assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) assert(model.hasParent) + + // copied model must have the same parent. 
+ MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + model.setSummary(None) + assert(!model.hasSummary) } test("empty probabilityCol") { @@ -251,9 +259,6 @@ class LogisticRegressionSuite mlr.setFitIntercept(false) val mlrModel = mlr.fit(smallMultinomialDataset) assert(mlrModel.interceptVector === Vectors.sparse(3, Seq())) - - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { @@ -1807,7 +1812,6 @@ class LogisticRegressionSuite .objectiveHistory .sliding(2) .forall(x => x(0) >= x(1))) - } test("binary logistic regression with weighted data") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f2368a9f8dad5..fc491cd6161fd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -41,6 +42,13 @@ class BisectingKMeansSuite assert(bkm.getPredictionCol === "prediction") assert(bkm.getMaxIter === 20) assert(bkm.getMinDivisibleClusterSize === 1.0) + val model = bkm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("setter/getter") { @@ -101,6 +109,9 @@ class BisectingKMeansSuite assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 003fa6abf6597..07299123f8a47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getPredictionCol === "prediction") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) + val model = gm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { @@ -103,6 +111,9 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + 
assert(!model.hasSummary) } test("read/write") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca392653557c4..c1b7242e11a8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + val model = kmeans.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { @@ -115,6 +123,9 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusterSizes.length === k) assert(clusterSizes.sum === numRows) assert(clusterSizes.forall(_ >= 0)) + + model.setSummary(None) + assert(!model.hasSummary) } test("KMeansModel transform with non-default feature and prediction cols") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 6af06d82d671a..80970fd744881 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -19,85 +19,72 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.feature import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Dataset, Row} class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - test("Test Chi-Square selector") { - import testImplicits._ - val data = Seq( - LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), - LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), - LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), - LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0))) - ) + @transient var dataset: Dataset[_] = _ - val preFilteredData = Seq( - Vectors.dense(8.0), - Vectors.dense(0.0), - Vectors.dense(0.0), - Vectors.dense(8.0) - ) + override def beforeAll(): Unit = { + super.beforeAll() - val df = sc.parallelize(data.zip(preFilteredData)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") - - val selector = new ChiSqSelector() - .setSelectorType("kbest") - .setNumTopFeatures(1) - .setFeaturesCol("data") - .setLabelCol("label") - .setOutputCol("filtered") - - 
selector.fit(df).transform(df).select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } - - selector.setSelectorType("percentile").setPercentile(0.34).fit(df).transform(df) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + // Toy dataset, including the top feature for a chi-squared test. + // These data are chosen such that each feature's test has a distinct p-value. + /* To verify the results with R, run: + library(stats) + x1 <- c(8.0, 0.0, 0.0, 7.0, 8.0) + x2 <- c(7.0, 9.0, 9.0, 9.0, 7.0) + x3 <- c(0.0, 6.0, 8.0, 5.0, 3.0) + y <- c(0.0, 1.0, 1.0, 2.0, 2.0) + chisq.test(x1,y) + chisq.test(x2,y) + chisq.test(x3,y) + */ + dataset = spark.createDataFrame(Seq( + (0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0))), Vectors.dense(8.0)), + (1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0))), Vectors.dense(0.0)), + (1.0, Vectors.dense(Array(0.0, 9.0, 8.0)), Vectors.dense(0.0)), + (2.0, Vectors.dense(Array(7.0, 9.0, 5.0)), Vectors.dense(7.0)), + (2.0, Vectors.dense(Array(8.0, 7.0, 3.0)), Vectors.dense(8.0)) + )).toDF("label", "features", "topFeature") + } - val preFilteredData2 = Seq( - Vectors.dense(8.0, 7.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(0.0, 9.0), - Vectors.dense(8.0, 9.0) - ) + test("params") { + ParamsSuite.checkParams(new ChiSqSelector) + val model = new ChiSqSelectorModel("myModel", + new org.apache.spark.mllib.feature.ChiSqSelectorModel(Array(1, 3, 4))) + ParamsSuite.checkParams(model) + } - val df2 = sc.parallelize(data.zip(preFilteredData2)) - .map(x => (x._1.label, x._1.features, x._2)) - .toDF("label", "data", "preFilteredData") + test("Test Chi-Square selector: numTopFeatures") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(1) + ChiSqSelectorSuite.testSelector(selector, dataset) + } - selector.setSelectorType("fpr").setAlpha(0.2).fit(df2).transform(df2) - .select("filtered", "preFilteredData").collect().foreach { - case Row(vec1: Vector, vec2: Vector) => - assert(vec1 ~== vec2 absTol 1e-1) - } + test("Test Chi-Square selector: percentile") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.34) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelector read/write") { - val t = new ChiSqSelector() - .setFeaturesCol("myFeaturesCol") - .setLabelCol("myLabelCol") - .setOutputCol("myOutputCol") - .setNumTopFeatures(2) - testDefaultReadWrite(t) + test("Test Chi-Square selector: fpr") { + val selector = new ChiSqSelector() + .setOutputCol("filtered").setSelectorType("fpr").setFpr(0.2) + ChiSqSelectorSuite.testSelector(selector, dataset) } - test("ChiSqSelectorModel read/write") { - val oldModel = new feature.ChiSqSelectorModel(Array(1, 3)) - val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel) - val newInstance = testDefaultReadWrite(instance) - assert(newInstance.selectedFeatures === instance.selectedFeatures) + test("read/write") { + def checkModelData(model: ChiSqSelectorModel, model2: ChiSqSelectorModel): Unit = { + assert(model.selectedFeatures === model2.selectedFeatures) + } + val nb = new ChiSqSelector + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { @@ -108,3 +95,25 @@ class ChiSqSelectorSuite extends 
SparkFunSuite with MLlibTestSparkContext } } } + +object ChiSqSelectorSuite { + + private def testSelector(selector: ChiSqSelector, dataset: Dataset[_]): Unit = { + selector.fit(dataset).transform(dataset).select("filtered", "topFeature").collect() + .foreach { case Row(vec1: Vector, vec2: Vector) => + assert(vec1 ~== vec2 absTol 1e-1) + } + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "selectorType" -> "percentile", + "numTopFeatures" -> 1, + "percentile" -> 0.12, + "outputCol" -> "myOutput" + ) +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index d0aa2cdfe0fd1..b923bacce23ca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ +import org.apache.spark.ml.recommendation.ALS.Rating import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -539,6 +540,13 @@ class ALSSuite }.getMessage.contains("was out of Integer range")) } } + + test("SPARK-18268: ALS with empty RDD should fail with better message") { + val ratings = sc.parallelize(Array.empty[Rating[Int]]) + intercept[IllegalArgumentException] { + ALS.train(ratings) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ac1ef5feb95ba..9b0fa67630d2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ @@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite @transient var datasetGaussianInverse: DataFrame = _ @transient var datasetBinomial: DataFrame = _ @transient var datasetPoissonLog: DataFrame = _ + @transient var datasetPoissonLogWithZero: DataFrame = _ @transient var datasetPoissonIdentity: DataFrame = _ @transient var datasetPoissonSqrt: DataFrame = _ @transient var datasetGammaInverse: DataFrame = _ @@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01, family = "poisson", link = "log").toDF() + datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput( + intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5), + xVariance = Array(0.7, 1.2), 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ac1ef5feb95ba..9b0fa67630d2e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random._
@@ -44,6 +44,7 @@ class GeneralizedLinearRegressionSuite
   @transient var datasetGaussianInverse: DataFrame = _
   @transient var datasetBinomial: DataFrame = _
   @transient var datasetPoissonLog: DataFrame = _
+  @transient var datasetPoissonLogWithZero: DataFrame = _
   @transient var datasetPoissonIdentity: DataFrame = _
   @transient var datasetPoissonSqrt: DataFrame = _
   @transient var datasetGammaInverse: DataFrame = _
@@ -88,6 +89,12 @@ class GeneralizedLinearRegressionSuite
       xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
       family = "poisson", link = "log").toDF()
 
+    datasetPoissonLogWithZero = generateGeneralizedLinearRegressionInput(
+      intercept = -1.5, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
+      xVariance = Array(0.7, 1.2), nPoints = 100, seed, noiseLevel = 0.01,
+      family = "poisson", link = "log")
+      .map{x => LabeledPoint(if (x.label < 0.7) 0.0 else x.label, x.features)}.toDF()
+
     datasetPoissonIdentity = generateGeneralizedLinearRegressionInput(
       intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
       xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -139,6 +146,10 @@ class GeneralizedLinearRegressionSuite
       label + "," + features.toArray.mkString(",")
     }.repartition(1).saveAsTextFile(
       "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLog")
+    datasetPoissonLogWithZero.rdd.map { case Row(label: Double, features: Vector) =>
+      label + "," + features.toArray.mkString(",")
+    }.repartition(1).saveAsTextFile(
+      "target/tmp/GeneralizedLinearRegressionSuite/datasetPoissonLogWithZero")
     datasetPoissonIdentity.rdd.map { case Row(label: Double, features: Vector) =>
       label + "," + features.toArray.mkString(",")
     }.repartition(1).saveAsTextFile(
@@ -183,6 +194,11 @@ class GeneralizedLinearRegressionSuite
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
 
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+    model.setSummary(None)
+    assert(!model.hasSummary)
     assert(model.getFeaturesCol === "features")
     assert(model.getPredictionCol === "prediction")
 
@@ -453,6 +469,40 @@ class GeneralizedLinearRegressionSuite
     }
   }
 
+  test("generalized linear regression: poisson family against glm (with zero values)") {
+    /*
+       R code:
+       f1 <- data$V1 ~ data$V2 + data$V3 - 1
+       f2 <- data$V1 ~ data$V2 + data$V3
+
+       data <- read.csv("path", header=FALSE)
+       for (formula in c(f1, f2)) {
+         model <- glm(formula, family="poisson", data=data)
+         print(as.vector(coef(model)))
+       }
+       [1] 0.4272661 -0.1565423
+       [1] -3.6911354 0.6214301 0.1295814
+     */
+    val expected = Seq(
+      Vectors.dense(0.0, 0.4272661, -0.1565423),
+      Vectors.dense(-3.6911354, 0.6214301, 0.1295814))
+
+    import GeneralizedLinearRegression._
+
+    var idx = 0
+    val link = "log"
+    val dataset = datasetPoissonLogWithZero
+    for (fitIntercept <- Seq(false, true)) {
+      val trainer = new GeneralizedLinearRegression().setFamily("poisson").setLink(link)
+        .setFitIntercept(fitIntercept).setLinkPredictionCol("linkPrediction")
+      val model = trainer.fit(dataset)
+      val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
+      assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with poisson family, " +
+        s"$link link and fitIntercept = $fitIntercept (with zero values).")
+      idx += 1
+    }
+  }
+
   test("generalized linear regression: gamma family against glm") {
     /*
       R code:
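For readers less familiar with the estimator exercised above, the following standalone sketch shows the same poisson-family, log-link fit from user code. The session setup and the tiny count dataset are made up for illustration; only the `GeneralizedLinearRegression` API calls mirror what the test relies on.

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.GeneralizedLinearRegression
import org.apache.spark.sql.SparkSession

// Illustrative only: a tiny poisson/log GLM fit, analogous to glm(..., family="poisson") in R.
val spark = SparkSession.builder().master("local[2]").appName("GlrPoissonExample").getOrCreate()

// Hypothetical count data; zero-valued labels are fine for the poisson family.
val df = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(1.0, 2.0)),
  (1.0, Vectors.dense(2.0, 1.0)),
  (2.0, Vectors.dense(3.0, 4.0)),
  (4.0, Vectors.dense(4.0, 3.0))
)).toDF("label", "features")

val glr = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("log")
  .setFitIntercept(true)

val model = glr.fit(df)
println(s"intercept = ${model.intercept}, coefficients = ${model.coefficients}")

spark.stop()
```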
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index c0e8afbf5e346..0be82742a33be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
@@ -143,6 +143,11 @@ class LinearRegressionSuite
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(model)
 
+    assert(model.hasSummary)
+    val copiedModel = model.copy(ParamMap.empty)
+    assert(copiedModel.hasSummary)
+    model.setSummary(None)
+    assert(!model.hasSummary)
     model.transform(datasetWithDenseFeature)
       .select("label", "prediction")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 87100ae2e342f..4463a9b6e543a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
@@ -78,6 +78,10 @@ class TrainValidationSplitSuite
       .setTrainRatio(0.5)
       .setSeed(42L)
     val cvModel = cv.fit(dataset)
+
+    // copied model must have the same parent.
+    MLTestingUtils.checkCopy(cvModel)
+
     val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
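The regression suites above now assert that `copy(ParamMap.empty)` preserves the training summary, and `MLTestingUtils.checkCopy` verifies that a copied model keeps its parent estimator. The snippet below sketches what such a check amounts to; `checkCopySketch` is a hypothetical stand-in, and the real helper may assert more than this.

```scala
import org.apache.spark.ml.Model
import org.apache.spark.ml.param.ParamMap

// Roughly what a "copied model must have the same parent" check amounts to;
// the actual MLTestingUtils.checkCopy helper may perform additional assertions.
def checkCopySketch[M <: Model[M]](model: M): Unit = {
  val copied = model.copy(ParamMap.empty)
  // copy() must return a model that still points back at the estimator that produced it.
  assert(copied.parent.eq(model.parent), "copy() must preserve the parent estimator")
}
```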
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index ac702b4b7c69e..77219e500617d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -54,33 +54,34 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
         LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2)
     val preFilteredData =
-      Set(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
+      Seq(LabeledPoint(0.0, Vectors.dense(Array(8.0))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0))))
     val model = new ChiSqSelector(1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
-    }.collect().toSet
-    assert(filteredData == preFilteredData)
+    }.collect().toSeq
+    assert(filteredData === preFilteredData)
   }
 
-  test("ChiSqSelector by FPR transform test (sparse & dense vector)") {
+  test("ChiSqSelector by fpr transform test (sparse & dense vector)") {
     val labeledDiscreteData = sc.parallelize(
       Seq(LabeledPoint(0.0, Vectors.sparse(4, Array((0, 8.0), (1, 7.0)))),
         LabeledPoint(1.0, Vectors.sparse(4, Array((1, 9.0), (2, 6.0), (3, 4.0)))),
         LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0, 4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0, 9.0)))), 2)
     val preFilteredData =
-      Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
+      Seq(LabeledPoint(0.0, Vectors.dense(Array(0.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(1.0, Vectors.dense(Array(4.0))),
         LabeledPoint(2.0, Vectors.dense(Array(9.0))))
-    val model = new ChiSqSelector().setSelectorType("fpr").setAlpha(0.1).fit(labeledDiscreteData)
+    val model: ChiSqSelectorModel = new ChiSqSelector().setSelectorType("fpr")
+      .setFpr(0.1).fit(labeledDiscreteData)
     val filteredData = labeledDiscreteData.map { lp =>
       LabeledPoint(lp.label, model.transform(lp.features))
-    }.collect().toSet
-    assert(filteredData == preFilteredData)
+    }.collect().toSeq
+    assert(filteredData === preFilteredData)
   }
 
   test("model load / save") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index d9dc557e3b2b9..b08ad99f4f204 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -188,6 +188,13 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext {
     testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
   }
 
+  test("SPARK-18268: ALS with empty RDD should fail with better message") {
+    val ratings = sc.parallelize(Array.empty[Rating])
+    intercept[IllegalArgumentException] {
+      new ALS().run(ratings)
+    }
+  }
+
   /**
    * Test if we can correctly factorize R = U * P where U and P are of known rank.
    *
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index e4e9be39ff6f9..665708a780c48 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -155,13 +155,17 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
     val tempDir = Utils.createTempDir()
     val outputDir = new File(tempDir, "output")
     MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString)
-    val lines = outputDir.listFiles()
+    val sources = outputDir.listFiles()
       .filter(_.getName.startsWith("part-"))
-      .flatMap(Source.fromFile(_).getLines())
-      .toSet
-    val expected = Set("1.1 1:1.23 3:4.56", "0.0 1:1.01 2:2.02 3:3.03")
-    assert(lines === expected)
-    Utils.deleteRecursively(tempDir)
+      .map(Source.fromFile)
+    Utils.tryWithSafeFinally {
+      val lines = sources.flatMap(_.getLines()).toSet
+      val expected = Set("1.1 1:1.23 3:4.56", "0.0 1:1.01 2:2.02 3:3.03")
+      assert(lines === expected)
+    } {
+      sources.foreach(_.close())
+      Utils.deleteRecursively(tempDir)
+    }
   }
 
   test("appendBias") {
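The MLUtilsSuite change above is about not leaking file handles: every `Source` that gets opened is now closed even if the read or the assertion fails. `Utils.tryWithSafeFinally` is a Spark-internal helper, so the sketch below expresses the same pattern with a plain `try`/`finally` and only the standard library; `readPartFiles` is an illustrative name, not code from this patch.

```scala
import scala.io.Source

// Open all part files under a directory, read them, and close every handle
// even if the read fails. This mirrors the intent of the test change above
// without relying on Spark's internal Utils.tryWithSafeFinally.
def readPartFiles(dir: java.io.File): Set[String] = {
  val sources = dir.listFiles()
    .filter(_.getName.startsWith("part-"))
    .map(Source.fromFile(_))
  try {
    sources.flatMap(_.getLines()).toSet
  } finally {
    sources.foreach(_.close())
  }
}
```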
diff --git a/pom.xml b/pom.xml
index e92d4eac1b7d3..89e92dbae4691 100644
--- a/pom.xml
+++ b/pom.xml
@@ -569,7 +569,7 @@
       <dependency>
         <groupId>io.netty</groupId>
         <artifactId>netty-all</artifactId>
-        <version>4.0.41.Final</version>
+        <version>4.0.42.Final</version>
       </dependency>
       <dependency>
         <groupId>io.netty</groupId>
@@ -1446,6 +1446,11 @@
           <exclusion>
             <groupId>jline</groupId>
             <artifactId>jline</artifactId>
           </exclusion>
+          <exclusion>
+            <groupId>org.json</groupId>
+            <artifactId>json</artifactId>
+          </exclusion>
         </exclusions>
@@ -2482,6 +2487,26 @@
           <artifactId>maven-javadoc-plugin</artifactId>
           <configuration>
             <additionalparam>-Xdoclint:all -Xdoclint:-missing</additionalparam>
+            <tags>
+              <tag>
+                <name>example</name>
+                <placement>a</placement>
+                <head>Example:</head>
+              </tag>
+              <tag>
+                <name>note</name>
+                <placement>a</placement>
+                <head>Note:</head>
+              </tag>
+              <tag>
+                <name>group</name>
+                <placement>X</placement>
+              </tag>
+              <tag>
+                <name>tparam</name>
+                <placement>X</placement>
+              </tag>
+            </tags>
           </configuration>
@@ -2704,6 +2729,54 @@
+    <profile>
+      <id>snapshots-and-staging</id>
+      <properties>
+        <asf.staging>https://repository.apache.org/content/groups/staging/</asf.staging>
+        <asf.snapshots>https://repository.apache.org/content/repositories/snapshots/</asf.snapshots>
+      </properties>
+
+      <pluginRepositories>
+        <pluginRepository>
+          <id>ASF Staging</id>
+          <url>${asf.staging}</url>
+        </pluginRepository>
+        <pluginRepository>
+          <id>ASF Snapshots</id>
+          <url>${asf.snapshots}</url>
+          <snapshots>
+            <enabled>true</enabled>
+          </snapshots>
+          <releases>
+            <enabled>false</enabled>
+          </releases>
+        </pluginRepository>
+      </pluginRepositories>
+
+      <repositories>
+        <repository>
+          <id>ASF Staging</id>
+          <url>${asf.staging}</url>
+        </repository>
+        <repository>
+          <id>ASF Snapshots</id>
+          <url>${asf.snapshots}</url>
+          <snapshots>
+            <enabled>true</enabled>
+          </snapshots>
+          <releases>
+            <enabled>false</enabled>
+          </releases>
+        </repository>
+      </repositories>
+    </profile>
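The `<tags>` block added to the `maven-javadoc-plugin` configuration registers scaladoc-style tags with javadoc so that doclint accepts them instead of failing on unknown tags. The doc comment below is purely illustrative of those tags in use; the method itself is made up and is not part of this patch.

```scala
/**
 * Returns the `k` largest elements of a sequence.
 *
 * @note Illustrative only: this comment exists to show the scaladoc-style tags
 *       (`@note`, `@example`, `@group`, `@tparam`) that the javadoc `<tags>`
 *       configuration above teaches javadoc to accept.
 * @example
 * {{{
 *   val top = selectTop(2)(Seq(3.0, 1.0, 2.0))  // Seq(3.0, 2.0)
 * }}}
 * @tparam T the element type, which must have an `Ordering`
 * @group util
 */
def selectTop[T: Ordering](k: Int)(xs: Seq[T]): Seq[T] = xs.sorted.reverse.take(k)
```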