diff --git a/.github/PULL_REQUEST_TEMPLATE b/.github/PULL_REQUEST_TEMPLATE index 0e41cf1826453..5af45d6fa7988 100644 --- a/.github/PULL_REQUEST_TEMPLATE +++ b/.github/PULL_REQUEST_TEMPLATE @@ -7,4 +7,4 @@ (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) -Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. +Please review http://spark.apache.org/contributing.html before opening a pull request. diff --git a/.gitignore b/.gitignore index 39d17e1793f77..1d91b43c23fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ dependency-reduced-pom.xml derby.log dev/create-release/*final dev/create-release/*txt +dev/pr-deps/ dist/ docs/_site docs/api @@ -57,6 +58,8 @@ project/plugins/project/build.properties project/plugins/src_managed/ project/plugins/target/ python/lib/pyspark.zip +python/deps +python/pyspark/python reports/ scalastyle-on-compile.generated.xml scalastyle-output.xml diff --git a/.travis.yml b/.travis.yml index 8739849a20798..d7e9f8c0290e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ dist: trusty # 2. Choose language and target JDKs for parallel builds. language: java jdk: - - oraclejdk7 - oraclejdk8 # 3. Setup cache directory for SBT and Maven. @@ -44,7 +43,7 @@ notifications: # 5. Run maven install before running lint-java. install: - export MAVEN_SKIP_RC=1 - - build/mvn -T 4 -q -DskipTests -Pmesos -Pyarn -Phadoop-2.3 -Pkinesis-asl -Phive -Phive-thriftserver install + - build/mvn -T 4 -q -DskipTests -Pmesos -Pyarn -Pkinesis-asl -Phive -Phive-thriftserver install # 6. Run lint-java. script: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1a8206abe3838..8fdd5aa9e7dfb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,12 @@ ## Contributing to Spark *Before opening a pull request*, review the -[Contributing to Spark wiki](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark). +[Contributing to Spark guide](http://spark.apache.org/contributing.html). It lists steps that are required before creating a PR. In particular, consider: - Is the change important and ready enough to ask the community to spend time reviewing? - Have you searched for existing, related JIRAs and pull requests? -- Is this a new feature that can stand alone as a [third party project](https://cwiki.apache.org/confluence/display/SPARK/Third+Party+Projects) ? +- Is this a new feature that can stand alone as a [third party project](http://spark.apache.org/third-party-projects.html) ? - Is the change being proposed clearly explained and motivated? When you contribute code, you affirm that the contribution is your original work and that you diff --git a/LICENSE b/LICENSE index 7950dd6ceb6db..c21032a1fd274 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/NOTICE b/NOTICE index 69b513ea3ba3c..f4b64b5c3f470 100644 --- a/NOTICE +++ b/NOTICE @@ -421,9 +421,6 @@ Copyright (c) 2011, Terrence Parr. 
This product includes/uses ASM (http://asm.ow2.org/), Copyright (c) 2000-2007 INRIA, France Telecom. -This product includes/uses org.json (http://www.json.org/java/index.html), -Copyright (c) 2002 JSON.org - This product includes/uses JLine (http://jline.sourceforge.net/), Copyright (c) 2002-2006, Marc Prud'hommeaux . diff --git a/R/CRAN_RELEASE.md b/R/CRAN_RELEASE.md new file mode 100644 index 0000000000000..d6084c7a7cc90 --- /dev/null +++ b/R/CRAN_RELEASE.md @@ -0,0 +1,91 @@ +# SparkR CRAN Release + +To release SparkR as a package to CRAN, we would use the `devtools` package. Please work with the +`dev@spark.apache.org` community and R package maintainer on this. + +### Release + +First, check that the `Version:` field in the `pkg/DESCRIPTION` file is updated. Also, check for stale files not under source control. + +Note that while `run-tests.sh` runs `check-cran.sh` (which runs `R CMD check`), it is doing so with `--no-manual --no-vignettes`, which skips a few vignettes or PDF checks - therefore it is preferable to run `R CMD check` on the source package built manually before uploading a release. Also note that for the CRAN checks of pdf vignettes to succeed, the `qpdf` tool must be installed (to install it, e.g. `yum -q -y install qpdf`). + +To upload a release, we would need to update the `cran-comments.md`. This should generally contain the results from running the `check-cran.sh` script along with comments on the status of all `WARNING` (there should not be any) or `NOTE` messages. As a part of `check-cran.sh` and the release process, the vignettes are built - make sure `SPARK_HOME` is set and Spark jars are accessible. + +Once everything is in place, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::release(); .libPaths(paths) +``` + +For more information please refer to http://r-pkgs.had.co.nz/release.html#release-check + +### Testing: build package manually + +To build the package manually, such as to inspect the resulting `.tar.gz` file content, we would also use the `devtools` package. + +The source package is what gets released to CRAN. CRAN would then build platform-specific binary packages from the source package. + +#### Build source package + +To build the source package locally without releasing to CRAN, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg"); .libPaths(paths) +``` + +(http://r-pkgs.had.co.nz/vignettes.html#vignette-workflow-2) + +Similarly, the source package is also created by `check-cran.sh` with `R CMD build pkg`. + +For example, this should be the content of the source package: + +```sh +DESCRIPTION R inst tests +NAMESPACE build man vignettes + +inst/doc/ +sparkr-vignettes.html +sparkr-vignettes.Rmd +sparkr-vignettes.Rman + +build/ +vignette.rds + +man/ + *.Rd files... + +vignettes/ +sparkr-vignettes.Rmd +``` + +#### Test source package + +To install, run this: + +```sh +R CMD INSTALL SparkR_2.1.0.tar.gz +``` + +Replace "2.1.0" with the version of SparkR. + +This command installs SparkR to the default libPaths. 
Once that is done, you should be able to start R and run: + +```R +library(SparkR) +vignette("sparkr-vignettes", package="SparkR") +``` + +#### Build binary package + +To build binary package locally, run in R under the `SPARK_HOME/R` directory: + +```R +paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); devtools::build("pkg", binary = TRUE); .libPaths(paths) +``` + +For example, this should be the content of the binary package: + +```sh +DESCRIPTION Meta R html tests +INDEX NAMESPACE help profile worker +``` diff --git a/R/README.md b/R/README.md index 932d5272d0b4f..4c40c5963db70 100644 --- a/R/README.md +++ b/R/README.md @@ -6,7 +6,7 @@ SparkR is an R package that provides a light-weight frontend to use Spark from R Libraries of sparkR need to be created in `$SPARK_HOME/R/lib`. This can be done by running the script `$SPARK_HOME/R/install-dev.sh`. By default the above script uses the system wide installation of R. However, this can be changed to any user installed location of R by setting the environment variable `R_HOME` the full path of the base directory where R is installed, before running install-dev.sh script. -Example: +Example: ```bash # where /home/username/R is where R is installed and /home/username/R/bin contains the files R and RScript export R_HOME=/home/username/R @@ -46,19 +46,19 @@ Sys.setenv(SPARK_HOME="/Users/username/spark") # This line loads SparkR from the installed directory .libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths())) library(SparkR) -sc <- sparkR.init(master="local") +sparkR.session() ``` #### Making changes to SparkR -The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR. +The [instructions](http://spark.apache.org/contributing.html) for making contributions to Spark also apply to SparkR. If you only make R file changes (i.e. no Scala changes) then you can just re-install the R package using `R/install-dev.sh` and test your changes. Once you have made your changes, please include unit tests for them and run existing unit tests using the `R/run-tests.sh` script as described below. - + #### Generating documentation The SparkR documentation (Rd files and HTML files) are not a part of the source repository. To generate them you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs and these packages need to be installed on the machine before using the script. Also, you may need to install these [prerequisites](https://github.com/apache/spark/tree/master/docs#prerequisites). See also, `R/DOCUMENTATION.md` - + ### Examples, Unit tests SparkR comes with several sample programs in the `examples/src/main/r` directory. diff --git a/R/WINDOWS.md b/R/WINDOWS.md index 1afcbfcabe85f..9ca7e58e20cd2 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -6,7 +6,7 @@ To build SparkR on Windows, the following steps are required include Rtools and R in `PATH`. 2. Install -[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +[JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set `JAVA_HOME` in the system environment variables. 3. Download and install [Maven](http://maven.apache.org/download.html). 
Also include the `bin` @@ -38,6 +38,6 @@ To run the SparkR unit tests on Windows, the following steps are required —ass ``` R -e "install.packages('testthat', repos='http://cran.us.r-project.org')" - .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + .\bin\spark-submit2.cmd --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R ``` diff --git a/R/check-cran.sh b/R/check-cran.sh index bb331466ae931..22cc9c6b601fc 100755 --- a/R/check-cran.sh +++ b/R/check-cran.sh @@ -20,30 +20,36 @@ set -o pipefail set -e -FWDIR="$(cd `dirname $0`; pwd)" -pushd $FWDIR > /dev/null +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null -if [ ! -z "$R_HOME" ] - then - R_SCRIPT_PATH="$R_HOME/bin" - else - # if system wide R_HOME is not found, then exit - if [ ! `command -v R` ]; then - echo "Cannot find 'R_HOME'. Please specify 'R_HOME' or make sure R is properly installed." - exit 1 - fi - R_SCRIPT_PATH="$(dirname $(which R))" +. "$FWDIR/find-r.sh" + +# Install the package (this is required for code in vignettes to run when building it later) +# Build the latest docs, but not vignettes, which is built with the package next +. "$FWDIR/install-dev.sh" + +# Build source package with vignettes +SPARK_HOME="$(cd "${FWDIR}"/..; pwd)" +. "${SPARK_HOME}/bin/load-spark-env.sh" +if [ -f "${SPARK_HOME}/RELEASE" ]; then + SPARK_JARS_DIR="${SPARK_HOME}/jars" +else + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" fi -echo "USING R_HOME = $R_HOME" -# Build the latest docs -$FWDIR/create-docs.sh +if [ -d "$SPARK_JARS_DIR" ]; then + # Build a zip file containing the source package with vignettes + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD build "$FWDIR/pkg" -# Build a zip file containing the source package -"$R_SCRIPT_PATH/"R CMD build $FWDIR/pkg + find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete +else + echo "Error Spark JARs not found in '$SPARK_HOME'" + exit 1 +fi # Run check as-cran. -VERSION=`grep Version $FWDIR/pkg/DESCRIPTION | awk '{print $NF}'` +VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` CRAN_CHECK_OPTIONS="--as-cran" @@ -54,11 +60,17 @@ fi if [ -n "$NO_MANUAL" ] then - CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual" + CRAN_CHECK_OPTIONS=$CRAN_CHECK_OPTIONS" --no-manual --no-vignettes" fi echo "Running CRAN check with $CRAN_CHECK_OPTIONS options" -"$R_SCRIPT_PATH/"R CMD check $CRAN_CHECK_OPTIONS SparkR_"$VERSION".tar.gz +if [ -n "$NO_TESTS" ] && [ -n "$NO_MANUAL" ] +then + "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" +else + # This will run tests and/or build vignettes, and require SPARK_HOME + SPARK_HOME="${SPARK_HOME}" "$R_SCRIPT_PATH/R" CMD check $CRAN_CHECK_OPTIONS "SparkR_$VERSION.tar.gz" +fi popd > /dev/null diff --git a/R/create-docs.sh b/R/create-docs.sh index 69ffc5f678c36..310dbc5fb50a3 100755 --- a/R/create-docs.sh +++ b/R/create-docs.sh @@ -20,7 +20,7 @@ # Script to create API docs and vignettes for SparkR # This requires `devtools`, `knitr` and `rmarkdown` to be installed on the machine. 
-# After running this script the html docs can be found in +# After running this script the html docs can be found in # $SPARK_HOME/R/pkg/html # The vignettes can be found in # $SPARK_HOME/R/pkg/vignettes/sparkr_vignettes.html @@ -29,18 +29,19 @@ set -o pipefail set -e # Figure out where the script is -export FWDIR="$(cd "`dirname "$0"`"; pwd)" -export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +export FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +export SPARK_HOME="$(cd "`dirname "${BASH_SOURCE[0]}"`"/..; pwd)" # Required for setting SPARK_SCALA_VERSION -. "${SPARK_HOME}"/bin/load-spark-env.sh +. "${SPARK_HOME}/bin/load-spark-env.sh" echo "Using Scala $SPARK_SCALA_VERSION" -pushd $FWDIR +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" # Install the package (this will also generate the Rd files) -./install-dev.sh +. "$FWDIR/install-dev.sh" # Now create HTML files @@ -48,25 +49,8 @@ pushd $FWDIR mkdir -p pkg/html pushd pkg/html -Rscript -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' +"$R_SCRIPT_PATH/Rscript" -e 'libDir <- "../../lib"; library(SparkR, lib.loc=libDir); library(knitr); knit_rd("SparkR", links = tools::findHTMLlinks(paste(libDir, "SparkR", sep="/")))' popd -# Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/jars" -else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" -fi - -# Only create vignettes if Spark JARs exist -if [ -d "$SPARK_JARS_DIR" ]; then - # render creates SparkR vignettes - Rscript -e 'library(rmarkdown); paths <- .libPaths(); .libPaths(c("lib", paths)); Sys.setenv(SPARK_HOME=tools::file_path_as_absolute("..")); render("pkg/vignettes/sparkr-vignettes.Rmd"); .libPaths(paths)' - - find pkg/vignettes/. -not -name '.' -not -name '*.Rmd' -not -name '*.md' -not -name '*.pdf' -not -name '*.html' -delete -else - echo "Skipping R vignettes as Spark JARs not found in $SPARK_HOME" -fi - popd diff --git a/R/create-rd.sh b/R/create-rd.sh new file mode 100755 index 0000000000000..ff622a41a46c0 --- /dev/null +++ b/R/create-rd.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + +set -o pipefail +set -e + +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. 
"$FWDIR/find-r.sh" + +# Generate Rd files if devtools is installed +"$R_SCRIPT_PATH/Rscript" -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' diff --git a/external/java8-tests/src/test/resources/log4j.properties b/R/find-r.sh old mode 100644 new mode 100755 similarity index 61% rename from external/java8-tests/src/test/resources/log4j.properties rename to R/find-r.sh index 3706a6e361307..690acc083af91 --- a/external/java8-tests/src/test/resources/log4j.properties +++ b/R/find-r.sh @@ -1,3 +1,5 @@ +#!/bin/bash + # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -15,13 +17,18 @@ # limitations under the License. # -# Set everything to be logged to the file target/unit-tests.log -log4j.rootCategory=INFO, file -log4j.appender.file=org.apache.log4j.FileAppender -log4j.appender.file.append=true -log4j.appender.file.file=target/unit-tests.log -log4j.appender.file.layout=org.apache.log4j.PatternLayout -log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n - -# Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark_project.jetty=WARN +if [ -z "$R_SCRIPT_PATH" ] +then + if [ ! -z "$R_HOME" ] + then + R_SCRIPT_PATH="$R_HOME/bin" + else + # if system wide R_HOME is not found, then exit + if [ ! `command -v R` ]; then + echo "Cannot find 'R_HOME'. Please specify 'R_HOME' or make sure R is properly installed." + exit 1 + fi + R_SCRIPT_PATH="$(dirname $(which R))" + fi + echo "Using R_SCRIPT_PATH = ${R_SCRIPT_PATH}" +fi diff --git a/R/install-dev.sh b/R/install-dev.sh index ada6303a722b7..d613552718307 100755 --- a/R/install-dev.sh +++ b/R/install-dev.sh @@ -29,33 +29,21 @@ set -o pipefail set -e -FWDIR="$(cd `dirname $0`; pwd)" +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" LIB_DIR="$FWDIR/lib" -mkdir -p $LIB_DIR - -pushd $FWDIR > /dev/null -if [ ! -z "$R_HOME" ] - then - R_SCRIPT_PATH="$R_HOME/bin" - else - # if system wide R_HOME is not found, then exit - if [ ! `command -v R` ]; then - echo "Cannot find 'R_HOME'. Please specify 'R_HOME' or make sure R is properly installed." - exit 1 - fi - R_SCRIPT_PATH="$(dirname $(which R))" -fi -echo "USING R_HOME = $R_HOME" - -# Generate Rd files if devtools is installed -"$R_SCRIPT_PATH/"Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }' +mkdir -p "$LIB_DIR" + +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" + +. "$FWDIR/create-rd.sh" # Install SparkR to $LIB_DIR -"$R_SCRIPT_PATH/"R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/ +"$R_SCRIPT_PATH/R" CMD INSTALL --library="$LIB_DIR" "$FWDIR/pkg/" # Zip the SparkR package so that it can be distributed to worker nodes on YARN -cd $LIB_DIR +cd "$LIB_DIR" jar cfM "$LIB_DIR/sparkr.zip" SparkR popd > /dev/null diff --git a/R/install-source-package.sh b/R/install-source-package.sh new file mode 100755 index 0000000000000..8de3569d1d482 --- /dev/null +++ b/R/install-source-package.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This scripts packages the SparkR source files (R and C files) and +# creates a package that can be loaded in R. The package is by default installed to +# $FWDIR/lib and the package can be loaded by using the following command in R: +# +# library(SparkR, lib.loc="$FWDIR/lib") +# +# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib to be the installation directory +# to load the SparkR package on the worker nodes. + +set -o pipefail +set -e + +FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)" +pushd "$FWDIR" > /dev/null +. "$FWDIR/find-r.sh" + +if [ -z "$VERSION" ]; then + VERSION=`grep Version "$FWDIR/pkg/DESCRIPTION" | awk '{print $NF}'` +fi + +if [ ! -f "$FWDIR/SparkR_$VERSION.tar.gz" ]; then + echo -e "R source package file '$FWDIR/SparkR_$VERSION.tar.gz' is not found." + echo -e "Please build R source package with check-cran.sh" + exit -1; +fi + +echo "Removing lib path and installing from source package" +LIB_DIR="$FWDIR/lib" +rm -rf "$LIB_DIR" +mkdir -p "$LIB_DIR" +"$R_SCRIPT_PATH/R" CMD INSTALL "SparkR_$VERSION.tar.gz" --library="$LIB_DIR" + +# Zip the SparkR package so that it can be distributed to worker nodes on YARN +pushd "$LIB_DIR" > /dev/null +jar cfM "$LIB_DIR/sparkr.zip" SparkR +popd > /dev/null + +popd diff --git a/R/pkg/.Rbuildignore b/R/pkg/.Rbuildignore index 544d203a6dce6..f12f8c275a989 100644 --- a/R/pkg/.Rbuildignore +++ b/R/pkg/.Rbuildignore @@ -1,5 +1,8 @@ ^.*\.Rproj$ ^\.Rproj\.user$ ^\.lintr$ +^cran-comments\.md$ +^NEWS\.md$ +^README\.Rmd$ ^src-native$ ^html$ diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 038236fc149e6..ae50b28ec6166 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) +linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 5a83883089e0e..879c1f80f2c5d 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,8 +1,8 @@ Package: SparkR Type: Package +Version: 2.2.0 Title: R Frontend for Apache Spark -Version: 2.0.0 -Date: 2016-08-27 +Description: The SparkR package provides an R Frontend for Apache Spark. 
Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), email = "shivaram@cs.berkeley.edu"), person("Xiangrui", "Meng", role = "aut", @@ -10,17 +10,18 @@ Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), person("Felix", "Cheung", role = "aut", email = "felixcheung@apache.org"), person(family = "The Apache Software Foundation", role = c("aut", "cph"))) +License: Apache License (== 2.0) URL: http://www.apache.org/ http://spark.apache.org/ -BugReports: https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark#ContributingtoSpark-ContributingBugReports +BugReports: http://spark.apache.org/contributing.html Depends: R (>= 3.0), methods Suggests: + knitr, + rmarkdown, testthat, e1071, survival -Description: The SparkR package provides an R frontend for Apache Spark. -License: Apache License (== 2.0) Collate: 'schema.R' 'generics.R' @@ -34,17 +35,27 @@ Collate: 'WindowSpec.R' 'backend.R' 'broadcast.R' + 'catalog.R' 'client.R' 'context.R' 'deserialize.R' 'functions.R' 'install.R' 'jvm.R' - 'mllib.R' + 'mllib_classification.R' + 'mllib_clustering.R' + 'mllib_fpm.R' + 'mllib_recommendation.R' + 'mllib_regression.R' + 'mllib_stat.R' + 'mllib_tree.R' + 'mllib_utils.R' 'serialize.R' 'sparkR.R' 'stats.R' + 'streaming.R' 'types.R' 'utils.R' 'window.R' RoxygenNote: 5.0.1 +VignetteBuilder: knitr diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 9cd6269f9a8f7..5c074d3c0fd40 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -1,9 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + # Imports from base R # Do not include stats:: "rpois", "runif" - causes error at runtime importFrom("methods", "setGeneric", "setMethod", "setOldClass") importFrom("methods", "is", "new", "signature", "show") importFrom("stats", "gaussian", "setNames") -importFrom("utils", "download.file", "object.size", "packageVersion", "untar") +importFrom("utils", "download.file", "object.size", "packageVersion", "tail", "untar") # Disable native libraries till we figure out how to package it # See SPARKR-7839 @@ -16,6 +33,7 @@ export("sparkR.stop") export("sparkR.session.stop") export("sparkR.conf") export("sparkR.version") +export("sparkR.uiWebUrl") export("print.jobj") export("sparkR.newJObject") @@ -45,7 +63,13 @@ exportMethods("glm", "spark.als", "spark.kstest", "spark.logit", - "spark.randomForest") + "spark.randomForest", + "spark.gbt", + "spark.bisectingKmeans", + "spark.svmLinear", + "spark.fpGrowth", + "spark.freqItemsets", + "spark.associationRules") # Job group lifecycle management methods export("setJobGroup", @@ -60,7 +84,10 @@ exportClasses("SparkDataFrame") exportMethods("arrange", "as.data.frame", "attach", + "broadcast", "cache", + "checkpoint", + "coalesce", "collect", "colnames", "colnames<-", @@ -75,6 +102,7 @@ exportMethods("arrange", "createOrReplaceTempView", "crossJoin", "crosstab", + "cube", "dapply", "dapplyCollect", "describe", @@ -92,12 +120,15 @@ exportMethods("arrange", "freqItems", "gapply", "gapplyCollect", + "getNumPartitions", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", + "isStreaming", "join", "limit", "merge", @@ -115,6 +146,7 @@ exportMethods("arrange", "registerTempTable", "rename", "repartition", + "rollup", "sample", "sample_frac", "sampleBy", @@ -131,6 +163,7 @@ exportMethods("arrange", "summarize", "summary", "take", + "toJSON", "transform", "union", "unionAll", @@ -145,12 +178,14 @@ exportMethods("arrange", "write.json", "write.orc", "write.parquet", + "write.stream", "write.text", "write.ml") exportClasses("Column") -exportMethods("%in%", +exportMethods("%<=>%", + "%in%", "abs", "acos", "add_months", @@ -173,6 +208,8 @@ exportMethods("%in%", "cbrt", "ceil", "ceiling", + "collect_list", + "collect_set", "column", "concat", "concat_ws", @@ -183,6 +220,8 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "create_array", + "create_map", "hash", "cume_dist", "date_add", @@ -198,6 +237,7 @@ exportMethods("%in%", "endsWith", "exp", "explode", + "explode_outer", "expm1", "expr", "factorial", @@ -205,17 +245,21 @@ exportMethods("%in%", "floor", "format_number", "format_string", + "from_json", "from_unixtime", "from_utc_timestamp", "getField", "getItem", "greatest", + "grouping_bit", + "grouping_id", "hex", "histogram", "hour", "hypot", "ifelse", "initcap", + "input_file_name", "instr", "isNaN", "isNotNull", @@ -253,18 +297,21 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "not", "ntile", "otherwise", "over", "percent_rank", "pmod", "posexplode", + "posexplode_outer", "quarter", "rand", "randn", "rank", "regexp_extract", "regexp_replace", + "repeat_string", "reverse", "rint", "rlike", @@ -288,6 +335,7 @@ exportMethods("%in%", "sort_array", "soundex", "spark_partition_id", + "split_string", "stddev", "stddev_pop", "stddev_samp", @@ -303,6 +351,8 @@ exportMethods("%in%", "toDegrees", "toRadians", "to_date", + "to_json", + "to_timestamp", "to_utc_timestamp", "translate", "trim", @@ -328,9 +378,15 @@ export("as.DataFrame", "clearCache", "createDataFrame", "createExternalTable", + "createTable", + 
"currentDatabase", "dropTempTable", "dropTempView", "jsonFile", + "listColumns", + "listDatabases", + "listFunctions", + "listTables", "loadDF", "parquetFile", "read.df", @@ -338,7 +394,13 @@ export("as.DataFrame", "read.json", "read.orc", "read.parquet", + "read.stream", "read.text", + "recoverPartitions", + "refreshByPath", + "refreshTable", + "setCheckpointDir", + "setCurrentDatabase", "spark.lapply", "spark.addFile", "spark.getSparkFilesRootDirectory", @@ -353,7 +415,9 @@ export("as.DataFrame", "read.ml", "print.summary.KSTest", "print.summary.RandomForestRegressionModel", - "print.summary.RandomForestClassificationModel") + "print.summary.RandomForestClassificationModel", + "print.summary.GBTRegressionModel", + "print.summary.GBTClassificationModel") export("structField", "structField.jobj", @@ -373,6 +437,16 @@ export("partitionBy", export("windowPartitionBy", "windowOrderBy") +exportClasses("StreamingQuery") + +export("awaitTermination", + "isActive", + "lastProgress", + "queryName", + "status", + "stopQuery") + + S3method(print, jobj) S3method(print, structField) S3method(print, structType) @@ -380,6 +454,8 @@ S3method(print, summary.GeneralizedLinearRegressionModel) S3method(print, summary.KSTest) S3method(print, summary.RandomForestRegressionModel) S3method(print, summary.RandomForestClassificationModel) +S3method(print, summary.GBTRegressionModel) +S3method(print, summary.GBTClassificationModel) S3method(structField, character) S3method(structField, jobj) S3method(structType, jobj) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 1cf9b38ea6483..aab2fc17aedaf 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -133,9 +133,6 @@ setMethod("schema", #' #' Print the logical and physical Catalyst plans to the console for debugging. #' -#' @param x a SparkDataFrame. -#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan. -#' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases explain,SparkDataFrame-method #' @rdname explain @@ -197,6 +194,7 @@ setMethod("isLocal", #' 20 characters will be truncated. However, if set greater than zero, #' truncates strings longer than \code{truncate} characters and all cells #' will be aligned right. +#' @param vertical whether print output rows vertically (one line per column value). #' @param ... further arguments to be passed to or from other methods. #' @family SparkDataFrame functions #' @aliases showDF,SparkDataFrame-method @@ -213,12 +211,13 @@ setMethod("isLocal", #' @note showDF since 1.4.0 setMethod("showDF", signature(x = "SparkDataFrame"), - function(x, numRows = 20, truncate = TRUE) { + function(x, numRows = 20, truncate = TRUE, vertical = FALSE) { if (is.logical(truncate) && truncate) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(20), vertical) } else { truncate2 <- as.numeric(truncate) - s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2)) + s <- callJMethod(x@sdf, "showString", numToInt(numRows), numToInt(truncate2), + vertical) } cat(s) }) @@ -280,7 +279,7 @@ setMethod("dtypes", #' Column Names of SparkDataFrame #' -#' Return all column names as a list. +#' Return a vector of column names. #' #' @param x a SparkDataFrame. 
#' @@ -323,10 +322,8 @@ setMethod("names", setMethod("names<-", signature(x = "SparkDataFrame"), function(x, value) { - if (!is.null(value)) { - sdf <- callJMethod(x@sdf, "toDF", as.list(value)) - dataFrame(sdf) - } + colnames(x) <- value + x }) #' @rdname columns @@ -340,7 +337,7 @@ setMethod("colnames", }) #' @param value a character vector. Must have the same length as the number -#' of columns in the SparkDataFrame. +#' of columns to be renamed. #' @rdname columns #' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- @@ -417,7 +414,7 @@ setMethod("coltypes", type <- PRIMITIVE_TYPES[[specialtype]] } } - type + type[[1]] }) # Find which types don't have mapping to R @@ -562,7 +559,7 @@ setMethod("insertInto", jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) write <- callJMethod(x@sdf, "write") write <- callJMethod(write, "mode", jmode) - callJMethod(write, "insertInto", tableName) + invisible(callJMethod(write, "insertInto", tableName)) }) #' Cache @@ -680,14 +677,53 @@ setMethod("storageLevel", storageLevelToString(callJMethod(x@sdf, "storageLevel")) }) +#' Coalesce +#' +#' Returns a new SparkDataFrame that has exactly \code{numPartitions} partitions. +#' This operation results in a narrow dependency, e.g. if you go from 1000 partitions to 100 +#' partitions, there will not be a shuffle, instead each of the 100 new partitions will claim 10 of +#' the current partitions. If a larger number of partitions is requested, it will stay at the +#' current number of partitions. +#' +#' However, if you're doing a drastic coalesce on a SparkDataFrame, e.g. to numPartitions = 1, +#' this may result in your computation taking place on fewer nodes than +#' you like (e.g. one node in the case of numPartitions = 1). To avoid this, +#' call \code{repartition}. This will add a shuffle step, but means the +#' current upstream partitions will be executed in parallel (per whatever +#' the current partitioning is). +#' +#' @param numPartitions the number of partitions to use. +#' +#' @family SparkDataFrame functions +#' @rdname coalesce +#' @name coalesce +#' @aliases coalesce,SparkDataFrame-method +#' @seealso \link{repartition} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' newDF <- coalesce(df, 1L) +#'} +#' @note coalesce(SparkDataFrame) since 2.1.1 +setMethod("coalesce", + signature(x = "SparkDataFrame"), + function(x, numPartitions) { + stopifnot(is.numeric(numPartitions)) + sdf <- callJMethod(x@sdf, "coalesce", numToInt(numPartitions)) + dataFrame(sdf) + }) + #' Repartition #' #' The following options for repartition are possible: #' \itemize{ -#' \item{1.} {Return a new SparkDataFrame partitioned by +#' \item{1.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.} +#' \item{2.} {Return a new SparkDataFrame hash partitioned by #' the given columns into \code{numPartitions}.} -#' \item{2.} {Return a new SparkDataFrame that has exactly \code{numPartitions}.} -#' \item{3.} {Return a new SparkDataFrame partitioned by the given column(s), +#' \item{3.} {Return a new SparkDataFrame hash partitioned by the given column(s), #' using \code{spark.sql.shuffle.partitions} as number of partitions.} #'} #' @param x a SparkDataFrame. 
@@ -699,6 +735,7 @@ setMethod("storageLevel", #' @rdname repartition #' @name repartition #' @aliases repartition,SparkDataFrame-method +#' @seealso \link{coalesce} #' @export #' @examples #'\dontrun{ @@ -737,26 +774,32 @@ setMethod("repartition", #' toJSON #' -#' Convert the rows of a SparkDataFrame into JSON objects and return an RDD where -#' each element contains a JSON string. +#' Converts a SparkDataFrame into a SparkDataFrame of JSON string. #' -#' @param x A SparkDataFrame -#' @return A StringRRDD of JSON objects +#' Each row is turned into a JSON document with columns as different fields. +#' The returned SparkDataFrame has a single character column with the name \code{value} +#' +#' @param x a SparkDataFrame +#' @return a SparkDataFrame +#' @family SparkDataFrame functions +#' @rdname toJSON +#' @name toJSON #' @aliases toJSON,SparkDataFrame-method -#' @noRd +#' @export #' @examples #'\dontrun{ #' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' newRDD <- toJSON(df) +#' path <- "path/to/file.parquet" +#' df <- read.parquet(path) +#' df_json <- toJSON(df) #'} +#' @note toJSON since 2.2.0 setMethod("toJSON", signature(x = "SparkDataFrame"), function(x) { - rdd <- callJMethod(x@sdf, "toJSON") - jrdd <- callJMethod(rdd, "toJavaRDD") - RDD(jrdd, serializedMode = "string") + jsonDS <- callJMethod(x@sdf, "toJSON") + df <- callJMethod(jsonDS, "toDF") + dataFrame(df) }) #' Save the contents of SparkDataFrame as a JSON file @@ -937,6 +980,8 @@ setMethod("unique", #' Sample #' #' Return a sampled subset of this SparkDataFrame using a random seed. +#' Note: this is not guaranteed to provide exactly the fraction specified +#' of the total count of of the given SparkDataFrame. #' #' @param x A SparkDataFrame #' @param withReplacement Sampling with replacement or not @@ -1130,6 +1175,7 @@ setMethod("collect", if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") + class(vec) <- PRIMITIVE_TYPES[[colType]] df[[colIndex]] <- vec } else { df[[colIndex]] <- col @@ -1277,7 +1323,7 @@ setMethod("toRDD", #' Groups the SparkDataFrame using the specified columns, so we can run aggregation on them. #' #' @param x a SparkDataFrame. -#' @param ... variable(s) (character names(s) or Column(s)) to group on. +#' @param ... character name(s) or Column(s) to group on. #' @return A GroupedData. #' @family SparkDataFrame functions #' @aliases groupBy,SparkDataFrame-method @@ -1293,6 +1339,7 @@ setMethod("toRDD", #' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max") #' } #' @note groupBy since 1.4.0 +#' @seealso \link{agg}, \link{cube}, \link{rollup} setMethod("groupBy", signature(x = "SparkDataFrame"), function(x, ...) { @@ -1709,6 +1756,23 @@ getColumn <- function(x, c) { column(callJMethod(x@sdf, "col", c)) } +setColumn <- function(x, c, value) { + if (class(value) != "Column" && !is.null(value)) { + if (isAtomicLengthOne(value)) { + value <- lit(value) + } else { + stop("value must be a Column, literal value as atomic in length of 1, or NULL") + } + } + + if (is.null(value)) { + nx <- drop(x, c) + } else { + nx <- withColumn(x, c, value) + } + nx +} + #' @param name name of a Column (without being wrapped by \code{""}). #' @rdname select #' @name $ @@ -1719,20 +1783,15 @@ setMethod("$", signature(x = "SparkDataFrame"), getColumn(x, name) }) -#' @param value a Column or \code{NULL}. If \code{NULL}, the specified Column is dropped. 
+#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. +#' If \code{NULL}, the specified Column is dropped. #' @rdname select #' @name $<- #' @aliases $<-,SparkDataFrame-method #' @note $<- since 1.4.0 setMethod("$<-", signature(x = "SparkDataFrame"), function(x, name, value) { - stopifnot(class(value) == "Column" || is.null(value)) - - if (is.null(value)) { - nx <- drop(x, name) - } else { - nx <- withColumn(x, name, value) - } + nx <- setColumn(x, name, value) x@sdf <- nx@sdf x }) @@ -1745,6 +1804,10 @@ setClassUnion("numericOrcharacter", c("numeric", "character")) #' @note [[ since 1.4.0 setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), function(x, i) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } if (is.numeric(i)) { cols <- columns(x) i <- cols[[i]] @@ -1752,6 +1815,25 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), getColumn(x, i) }) +#' @rdname subset +#' @name [[<- +#' @aliases [[<-,SparkDataFrame,numericOrcharacter-method +#' @note [[<- since 2.1.1 +setMethod("[[<-", signature(x = "SparkDataFrame", i = "numericOrcharacter"), + function(x, i, value) { + if (length(i) > 1) { + warning("Subset index has length > 1. Only the first index is used.") + i <- i[1] + } + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + nx <- setColumn(x, i, value) + x@sdf <- nx@sdf + x + }) + #' @rdname subset #' @name [ #' @aliases [,SparkDataFrame-method @@ -1796,14 +1878,19 @@ setMethod("[", signature(x = "SparkDataFrame"), #' Return subsets of SparkDataFrame according to given conditions #' @param x a SparkDataFrame. #' @param i,subset (Optional) a logical expression to filter on rows. +#' For extract operator [[ and replacement operator [[<-, the indexing parameter for +#' a single Column. #' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. #' Otherwise, a SparkDataFrame will always be returned. +#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. +#' If \code{NULL}, the specified Column is dropped. #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. #' @export #' @family SparkDataFrame functions #' @aliases subset,SparkDataFrame-method +#' @seealso \link{withColumn} #' @rdname subset #' @name subset #' @family subsetting functions @@ -1821,6 +1908,10 @@ setMethod("[", signature(x = "SparkDataFrame"), #' subset(df, df$age %in% c(19, 30), 1:2) #' subset(df, df$age %in% c(19), select = c(1,2)) #' subset(df, select = c(1,2)) +#' # Columns can be selected and set +#' df[["age"]] <- 23 +#' df[[1]] <- df$age +#' df[[2]] <- NULL # drop column #' } #' @note subset since 1.5.0 setMethod("subset", signature(x = "SparkDataFrame"), @@ -1939,13 +2030,13 @@ setMethod("selectExpr", #' #' @param x a SparkDataFrame. #' @param colName a column name. -#' @param col a Column expression. +#' @param col a Column expression, or an atomic vector in the length of 1 as literal value. #' @return A SparkDataFrame with the new column added or the existing column replaced. 
#' @family SparkDataFrame functions -#' @aliases withColumn,SparkDataFrame,character,Column-method +#' @aliases withColumn,SparkDataFrame,character-method #' @rdname withColumn #' @name withColumn -#' @seealso \link{rename} \link{mutate} +#' @seealso \link{rename} \link{mutate} \link{subset} #' @export #' @examples #'\dontrun{ @@ -1955,11 +2046,20 @@ setMethod("selectExpr", #' newDF <- withColumn(df, "newCol", df$col1 * 5) #' # Replace an existing column #' newDF2 <- withColumn(newDF, "newCol", newDF$col1) +#' newDF3 <- withColumn(newDF, "newCol", 42) +#' # Use extract operator to set an existing or new column +#' df[["age"]] <- 23 +#' df[[2]] <- df$col1 +#' df[[2]] <- NULL # drop column #' } #' @note withColumn since 1.4.0 setMethod("withColumn", - signature(x = "SparkDataFrame", colName = "character", col = "Column"), + signature(x = "SparkDataFrame", colName = "character"), function(x, colName, col) { + if (class(col) != "Column") { + if (!isAtomicLengthOne(col)) stop("Literal value must be atomic in length of 1") + col <- lit(col) + } sdf <- callJMethod(x@sdf, "withColumn", colName, col@jc) dataFrame(sdf) }) @@ -2305,9 +2405,9 @@ setMethod("dropDuplicates", #' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a #' Column expression. If joinExpr is omitted, the default, inner join is attempted and an error is #' thrown if it would be a Cartesian Product. For Cartesian join, use crossJoin instead. -#' @param joinType The type of join to perform. The following join types are available: -#' 'inner', 'outer', 'full', 'fullouter', leftouter', 'left_outer', 'left', -#' 'right_outer', 'rightouter', 'right', and 'leftsemi'. The default joinType is "inner". +#' @param joinType The type of join to perform, default 'inner'. +#' Must be one of: 'inner', 'cross', 'outer', 'full', 'full_outer', +#' 'left', 'left_outer', 'right', 'right_outer', 'left_semi', or 'left_anti'. #' @return A SparkDataFrame containing the result of the join operation. #' @family SparkDataFrame functions #' @aliases join,SparkDataFrame,SparkDataFrame-method @@ -2336,15 +2436,18 @@ setMethod("join", if (is.null(joinType)) { sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc) } else { - if (joinType %in% c("inner", "outer", "full", "fullouter", - "leftouter", "left_outer", "left", - "rightouter", "right_outer", "right", "leftsemi")) { + if (joinType %in% c("inner", "cross", + "outer", "full", "fullouter", "full_outer", + "left", "leftouter", "left_outer", + "right", "rightouter", "right_outer", + "left_semi", "leftsemi", "left_anti", "leftanti")) { joinType <- gsub("_", "", joinType) sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType) } else { stop("joinType must be one of the following types: ", - "'inner', 'outer', 'full', 'fullouter', 'leftouter', 'left_outer', 'left', - 'rightouter', 'right_outer', 'right', 'leftsemi'") + "'inner', 'cross', 'outer', 'full', 'full_outer',", + "'left', 'left_outer', 'right', 'right_outer',", + "'left_semi', or 'left_anti'.") } } } @@ -2539,7 +2642,9 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' Input SparkDataFrames can have different schemas (names and data types). +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. 
#' #' @param x A SparkDataFrame #' @param y A SparkDataFrame @@ -2581,8 +2686,10 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' -#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. -#' Note that this does not remove duplicate rows across the two SparkDataFrames. +#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method +#' requires that the input SparkDataFrames have the same column names. +#' +#' Note: This does not remove duplicate rows across the two SparkDataFrames. #' #' @param x a SparkDataFrame. #' @param ... additional SparkDataFrame(s). @@ -2604,6 +2711,10 @@ setMethod("unionAll", setMethod("rbind", signature(... = "SparkDataFrame"), function(x, ..., deparse.level = 1) { + nm <- lapply(list(x, ...), names) + if (length(unique(nm)) != 1) { + stop("Names of input data frames are different.") + } if (nargs() == 3) { union(x, ...) } else { @@ -2710,14 +2821,14 @@ setMethod("write.df", signature(df = "SparkDataFrame"), function(df, path = NULL, source = NULL, mode = "error", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", "in 'spark.sql.sources.default' configuration by default.") } if (!is.character(mode)) { - stop("mode should be charactor or omitted. It is 'error' by default.") + stop("mode should be character or omitted. It is 'error' by default.") } if (is.null(source)) { source <- getDefaultSqlSource() @@ -2786,7 +2897,7 @@ setMethod("saveAsTable", write <- callJMethod(write, "format", source) write <- callJMethod(write, "mode", jmode) write <- callJMethod(write, "options", options) - callJMethod(write, "saveAsTable", tableName) + invisible(callJMethod(write, "saveAsTable", tableName)) }) #' summary @@ -2932,7 +3043,7 @@ setMethod("fillna", signature(x = "SparkDataFrame"), function(x, value, cols = NULL) { if (!(class(value) %in% c("integer", "numeric", "character", "list"))) { - stop("value should be an integer, numeric, charactor or named list.") + stop("value should be an integer, numeric, character or named list.") } if (class(value) == "list") { @@ -2944,7 +3055,7 @@ setMethod("fillna", # Check each item in the named list is of valid type lapply(value, function(v) { if (!(class(v) %in% c("integer", "numeric", "character"))) { - stop("Each item in value should be an integer, numeric or charactor.") + stop("Each item in value should be an integer, numeric or character.") } }) @@ -3381,3 +3492,309 @@ setMethod("randomSplit", } sapply(sdfs, dataFrame) }) + +#' getNumPartitions +#' +#' Return the number of partitions +#' +#' @param x A SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases getNumPartitions,SparkDataFrame-method +#' @rdname getNumPartitions +#' @name getNumPartitions +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createDataFrame(cars, numPartitions = 2) +#' getNumPartitions(df) +#' } +#' @note getNumPartitions since 2.1.1 +setMethod("getNumPartitions", + signature(x = "SparkDataFrame"), + function(x) { + callJMethod(callJMethod(x@sdf, "rdd"), "getNumPartitions") + }) + +#' isStreaming +#' +#' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data +#' as it arrives. 
+#' +#' @param x A SparkDataFrame +#' @return TRUE if this SparkDataFrame is from a streaming source +#' @family SparkDataFrame functions +#' @aliases isStreaming,SparkDataFrame-method +#' @rdname isStreaming +#' @name isStreaming +#' @seealso \link{read.stream} \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' } +#' @note isStreaming since 2.2.0 +#' @note experimental +setMethod("isStreaming", + signature(x = "SparkDataFrame"), + function(x) { + callJMethod(x@sdf, "isStreaming") + }) + +#' Write the streaming SparkDataFrame to a data source. +#' +#' The data source is specified by the \code{source} and a set of options (...). +#' If \code{source} is not specified, the default data source configured by +#' spark.sql.sources.default will be used. +#' +#' Additionally, \code{outputMode} specifies how data of a streaming SparkDataFrame is written to a +#' output data source. There are three modes: +#' \itemize{ +#' \item append: Only the new rows in the streaming SparkDataFrame will be written out. This +#' output mode can be only be used in queries that do not contain any aggregation. +#' \item complete: All the rows in the streaming SparkDataFrame will be written out every time +#' there are some updates. This output mode can only be used in queries that +#' contain aggregations. +#' \item update: Only the rows that were updated in the streaming SparkDataFrame will be written +#' out every time there are some updates. If the query doesn't contain aggregations, +#' it will be equivalent to \code{append} mode. +#' } +#' +#' @param df a streaming SparkDataFrame. +#' @param source a name for external data source. +#' @param outputMode one of 'append', 'complete', 'update'. +#' @param ... additional argument(s) passed to the method. +#' +#' @family SparkDataFrame functions +#' @seealso \link{read.stream} +#' @aliases write.stream,SparkDataFrame-method +#' @rdname write.stream +#' @name write.stream +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' isStreaming(df) +#' wordCounts <- count(group_by(df, "value")) +#' +#' # console +#' q <- write.stream(wordCounts, "console", outputMode = "complete") +#' # text stream +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' # memory stream +#' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") +#' head(sql("SELECT * from outs")) +#' queryName(q) +#' +#' stopQuery(q) +#' } +#' @note write.stream since 2.2.0 +#' @note experimental +setMethod("write.stream", + signature(df = "SparkDataFrame"), + function(df, source = NULL, outputMode = NULL, ...) { + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (!is.null(outputMode) && !is.character(outputMode)) { + stop("outputMode should be character or omitted.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) 
+ write <- handledCallJMethod(df@sdf, "writeStream") + write <- callJMethod(write, "format", source) + if (!is.null(outputMode)) { + write <- callJMethod(write, "outputMode", outputMode) + } + write <- callJMethod(write, "options", options) + ssq <- handledCallJMethod(write, "start") + streamingQuery(ssq) + }) + +#' checkpoint +#' +#' Returns a checkpointed version of this SparkDataFrame. Checkpointing can be used to truncate the +#' logical plan, which is especially useful in iterative algorithms where the plan may grow +#' exponentially. It will be saved to files inside the checkpoint directory set with +#' \code{setCheckpointDir} +#' +#' @param x A SparkDataFrame +#' @param eager whether to checkpoint this SparkDataFrame immediately +#' @return a new checkpointed SparkDataFrame +#' @family SparkDataFrame functions +#' @aliases checkpoint,SparkDataFrame-method +#' @rdname checkpoint +#' @name checkpoint +#' @seealso \link{setCheckpointDir} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#' df <- checkpoint(df) +#' } +#' @note checkpoint since 2.2.0 +setMethod("checkpoint", + signature(x = "SparkDataFrame"), + function(x, eager = TRUE) { + df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) + dataFrame(df) + }) + +#' cube +#' +#' Create a multi-dimensional cube for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases cube,SparkDataFrame-method +#' @rdname cube +#' @name cube +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(cube(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(cube(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note cube since 2.3.0 +#' @seealso \link{agg}, \link{groupBy}, \link{rollup} +setMethod("cube", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "cube", jcol) + groupedData(sgd) + }) + +#' rollup +#' +#' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns. +#' +#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to +#' direct application of \link{agg}. +#' +#' @param x a SparkDataFrame. +#' @param ... character name(s) or Column(s) to group on. +#' @return A GroupedData. +#' @family SparkDataFrame functions +#' @aliases rollup,SparkDataFrame-method +#' @rdname rollup +#' @name rollup +#' @export +#' @examples +#'\dontrun{ +#' df <- createDataFrame(mtcars) +#' mean(rollup(df, "cyl", "gear", "am"), "mpg") +#' +#' # Following calls are equivalent +#' agg(rollup(carsDF), mean(carsDF$mpg)) +#' agg(carsDF, mean(carsDF$mpg)) +#' } +#' @note rollup since 2.3.0 +#' @seealso \link{agg}, \link{cube}, \link{groupBy} +setMethod("rollup", + signature(x = "SparkDataFrame"), + function(x, ...) { + cols <- list(...) + jcol <- lapply(cols, function(x) if (class(x) == "Column") x@jc else column(x)@jc) + sgd <- callJMethod(x@sdf, "rollup", jcol) + groupedData(sgd) + }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... 
optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) + +#' alias +#' +#' @aliases alias,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname alias +#' @name alias +#' @export +#' @examples +#' \dontrun{ +#' df <- alias(createDataFrame(mtcars), "mtcars") +#' avg_mpg <- alias(agg(groupBy(df, df$cyl), avg(df$mpg)), "avg_mpg") +#' +#' head(select(df, column("mtcars.mpg"))) +#' head(join(df, avg_mpg, column("mtcars.cyl") == column("avg_mpg.cyl"))) +#' } +#' @note alias(SparkDataFrame) since 2.3.0 +setMethod("alias", + signature(object = "SparkDataFrame"), + function(object, data) { + stopifnot(is.character(data)) + sdf <- callJMethod(object@sdf, "alias", data) + dataFrame(sdf) + }) + +#' broadcast +#' +#' Return a new SparkDataFrame marked as small enough for use in broadcast joins. +#' +#' Equivalent to \code{hint(x, "broadcast")}. +#' +#' @param x a SparkDataFrame. +#' @return a SparkDataFrame. +#' +#' @aliases broadcast,SparkDataFrame-method +#' @family SparkDataFrame functions +#' @rdname broadcast +#' @name broadcast +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, broadcast(avg_mpg), df$cyl == avg_mpg$cyl)) +#' } +#' @note broadcast since 2.3.0 +setMethod("broadcast", + signature(x = "SparkDataFrame"), + function(x) { + sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) + dataFrame(sdf) + }) diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 0f1162fec1df9..7ad3993e9ecbc 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -291,7 +291,7 @@ setMethod("unpersistRDD", #' @rdname checkpoint-methods #' @aliases checkpoint,RDD-method #' @noRd -setMethod("checkpoint", +setMethod("checkpointRDD", signature(x = "RDD"), function(x) { jrdd <- getJRDD(x) @@ -313,7 +313,7 @@ setMethod("checkpoint", #' @rdname getNumPartitions #' @aliases getNumPartitions,RDD-method #' @noRd -setMethod("getNumPartitions", +setMethod("getNumPartitionsRDD", signature(x = "RDD"), function(x) { callJMethod(getJRDD(x), "getNumPartitions") @@ -329,7 +329,7 @@ setMethod("numPartitions", signature(x = "RDD"), function(x) { .Deprecated("getNumPartitions") - getNumPartitions(x) + getNumPartitionsRDD(x) }) #' Collect elements of an RDD @@ -460,7 +460,7 @@ setMethod("countByValue", signature(x = "RDD"), function(x) { ones <- lapply(x, function(item) { list(item, 1L) }) - collectRDD(reduceByKey(ones, `+`, getNumPartitions(x))) + collectRDD(reduceByKey(ones, `+`, getNumPartitionsRDD(x))) }) #' Apply a function to all elements @@ -780,7 +780,7 @@ setMethod("takeRDD", resList <- list() index <- -1 jrdd <- getJRDD(x) - numPartitions <- getNumPartitions(x) + numPartitions <- getNumPartitionsRDD(x) serializedModeRDD <- getSerializedMode(x) # TODO(shivaram): Collect more than one partition based on size @@ -846,7 +846,7 @@ setMethod("firstRDD", #' @noRd 
setMethod("distinctRDD", signature(x = "RDD"), - function(x, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, numPartitions = SparkR:::getNumPartitionsRDD(x)) { identical.mapped <- lapply(x, function(x) { list(x, NULL) }) reduced <- reduceByKey(identical.mapped, function(x, y) { x }, @@ -1028,7 +1028,7 @@ setMethod("repartitionRDD", signature(x = "RDD"), function(x, numPartitions) { if (!is.null(numPartitions) && is.numeric(numPartitions)) { - coalesce(x, numPartitions, TRUE) + coalesceRDD(x, numPartitions, TRUE) } else { stop("Please, specify the number of partitions") } @@ -1049,11 +1049,11 @@ setMethod("repartitionRDD", #' @rdname coalesce #' @aliases coalesce,RDD #' @noRd -setMethod("coalesce", +setMethod("coalesceRDD", signature(x = "RDD", numPartitions = "numeric"), function(x, numPartitions, shuffle = FALSE) { numPartitions <- numToInt(numPartitions) - if (shuffle || numPartitions > SparkR:::getNumPartitions(x)) { + if (shuffle || numPartitions > SparkR:::getNumPartitionsRDD(x)) { func <- function(partIndex, part) { set.seed(partIndex) # partIndex as seed start <- as.integer(base::sample(numPartitions, 1) - 1) @@ -1143,7 +1143,7 @@ setMethod("saveAsTextFile", #' @noRd setMethod("sortBy", signature(x = "RDD", func = "function"), - function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, func, ascending = TRUE, numPartitions = SparkR:::getNumPartitionsRDD(x)) { values(sortByKey(keyBy(x, func), ascending, numPartitions)) }) @@ -1175,7 +1175,7 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { resList <- list() index <- -1 jrdd <- getJRDD(newRdd) - numPartitions <- getNumPartitions(newRdd) + numPartitions <- getNumPartitionsRDD(newRdd) serializedModeRDD <- getSerializedMode(newRdd) while (TRUE) { @@ -1407,7 +1407,7 @@ setMethod("setName", setMethod("zipWithUniqueId", signature(x = "RDD"), function(x) { - n <- getNumPartitions(x) + n <- getNumPartitionsRDD(x) partitionFunc <- function(partIndex, part) { mapply( @@ -1450,7 +1450,7 @@ setMethod("zipWithUniqueId", setMethod("zipWithIndex", signature(x = "RDD"), function(x) { - n <- getNumPartitions(x) + n <- getNumPartitionsRDD(x) if (n > 1) { nums <- collectRDD(lapplyPartition(x, function(part) { @@ -1566,8 +1566,8 @@ setMethod("unionRDD", setMethod("zipRDD", signature(x = "RDD", other = "RDD"), function(x, other) { - n1 <- getNumPartitions(x) - n2 <- getNumPartitions(other) + n1 <- getNumPartitionsRDD(x) + n2 <- getNumPartitionsRDD(other) if (n1 != n2) { stop("Can only zip RDDs which have the same number of partitions.") } @@ -1637,7 +1637,7 @@ setMethod("cartesian", #' @noRd setMethod("subtract", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { mapFunction <- function(e) { list(e, NA) } rdd1 <- map(x, mapFunction) rdd2 <- map(other, mapFunction) @@ -1671,7 +1671,7 @@ setMethod("subtract", #' @noRd setMethod("intersection", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { rdd1 <- map(x, function(v) { list(v, NA) }) rdd2 <- map(other, function(v) { list(v, NA) }) @@ -1714,7 +1714,7 @@ setMethod("zipPartitions", if (length(rrdds) == 1) { return(rrdds[[1]]) } - nPart <- sapply(rrdds, getNumPartitions) + nPart <- sapply(rrdds, getNumPartitionsRDD) if (length(unique(nPart)) != 1) { stop("Can only zipPartitions RDDs which have the 
same number of partitions.") } diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 38d83c6e5c52b..f5c3a749fe0a1 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -184,8 +184,11 @@ getDefaultSqlSource <- function() { #' #' Converts R data.frame or list into SparkDataFrame. #' -#' @param data an RDD or list or data.frame. +#' @param data a list or data.frame. #' @param schema a list of column names or named list (StructType), optional. +#' @param samplingRatio Currently not used. +#' @param numPartitions the number of partitions of the SparkDataFrame. Defaults to 1, this is +#' limited by length of the list or number of rows of the data.frame #' @return A SparkDataFrame. #' @rdname createDataFrame #' @export @@ -195,12 +198,14 @@ getDefaultSqlSource <- function() { #' df1 <- as.DataFrame(iris) #' df2 <- as.DataFrame(list(3,4,5,6)) #' df3 <- createDataFrame(iris) +#' df4 <- createDataFrame(cars, numPartitions = 2) #' } #' @name createDataFrame #' @method createDataFrame default #' @note createDataFrame since 1.4.0 # TODO(davies): support sampling and infer type from NA -createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { +createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, + numPartitions = NULL) { sparkSession <- getSparkSession() if (is.data.frame(data)) { @@ -233,7 +238,11 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { if (is.list(data)) { sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) - rdd <- parallelize(sc, data) + if (!is.null(numPartitions)) { + rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions)) + } else { + rdd <- parallelize(sc, data, numSlices = 1) + } } else if (inherits(data, "RDD")) { rdd <- data } else { @@ -283,14 +292,13 @@ createDataFrame <- function(x, ...) { dispatchFunc("createDataFrame(data, schema = NULL)", x, ...) } -#' @param samplingRatio Currently not used. #' @rdname createDataFrame #' @aliases createDataFrame #' @export #' @method as.DataFrame default #' @note as.DataFrame since 1.6.0 -as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) { - createDataFrame(data, schema) +as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) { + createDataFrame(data, schema, samplingRatio, numPartitions) } #' @param ... additional argument(s). @@ -324,8 +332,10 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. #' -#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ), returning the result as a SparkDataFrame +#' Loads a JSON file, returning the result as a SparkDataFrame +#' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. @@ -338,6 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' df <- read.json(path, wholeFile = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -533,12 +544,15 @@ sql <- function(x, ...) { dispatchFunc("sql(sqlQuery)", x, ...) 
} -#' Create a SparkDataFrame from a SparkSQL Table +#' Create a SparkDataFrame from a SparkSQL table or view #' -#' Returns the specified Table as a SparkDataFrame. The Table must have already been registered -#' in the SparkSession. +#' Returns the specified table or view as a SparkDataFrame. The table or view must already exist or +#' have already been registered in the SparkSession. #' -#' @param tableName The SparkSQL Table to convert to a SparkDataFrame. +#' @param tableName the qualified or unqualified name that designates a table or view. If a database +#' is specified, it identifies the table/view from the database. +#' Otherwise, it first attempts to find a temporary view with the given name +#' and then match the table/view from the current database. #' @return SparkDataFrame #' @rdname tableToDF #' @name tableToDF @@ -558,199 +572,6 @@ tableToDF <- function(tableName) { dataFrame(sdf) } -#' Tables -#' -#' Returns a SparkDataFrame containing names of tables in the given database. -#' -#' @param databaseName name of the database -#' @return a SparkDataFrame -#' @rdname tables -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tables("hive") -#' } -#' @name tables -#' @method tables default -#' @note tables since 1.4.0 -tables.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName) - dataFrame(jdf) -} - -tables <- function(x, ...) { - dispatchFunc("tables(databaseName = NULL)", x, ...) -} - -#' Table Names -#' -#' Returns the names of tables in the given database as an array. -#' -#' @param databaseName name of the database -#' @return a list of table names -#' @rdname tableNames -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' tableNames("hive") -#' } -#' @name tableNames -#' @method tableNames default -#' @note tableNames since 1.4.0 -tableNames.default <- function(databaseName = NULL) { - sparkSession <- getSparkSession() - callJStatic("org.apache.spark.sql.api.r.SQLUtils", - "getTableNames", - sparkSession, - databaseName) -} - -tableNames <- function(x, ...) { - dispatchFunc("tableNames(databaseName = NULL)", x, ...) -} - -#' Cache Table -#' -#' Caches the specified table in-memory. -#' -#' @param tableName The name of the table being cached -#' @return SparkDataFrame -#' @rdname cacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' cacheTable("table") -#' } -#' @name cacheTable -#' @method cacheTable default -#' @note cacheTable since 1.4.0 -cacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "cacheTable", tableName) -} - -cacheTable <- function(x, ...) { - dispatchFunc("cacheTable(tableName)", x, ...) -} - -#' Uncache Table -#' -#' Removes the specified table from the in-memory cache. 
-#' -#' @param tableName The name of the table being uncached -#' @return SparkDataFrame -#' @rdname uncacheTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' path <- "path/to/file.json" -#' df <- read.json(path) -#' createOrReplaceTempView(df, "table") -#' uncacheTable("table") -#' } -#' @name uncacheTable -#' @method uncacheTable default -#' @note uncacheTable since 1.4.0 -uncacheTable.default <- function(tableName) { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "uncacheTable", tableName) -} - -uncacheTable <- function(x, ...) { - dispatchFunc("uncacheTable(tableName)", x, ...) -} - -#' Clear Cache -#' -#' Removes all cached tables from the in-memory cache. -#' -#' @rdname clearCache -#' @export -#' @examples -#' \dontrun{ -#' clearCache() -#' } -#' @name clearCache -#' @method clearCache default -#' @note clearCache since 1.4.0 -clearCache.default <- function() { - sparkSession <- getSparkSession() - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "clearCache") -} - -clearCache <- function() { - dispatchFunc("clearCache()") -} - -#' (Deprecated) Drop Temporary Table -#' -#' Drops the temporary table with the given table name in the catalog. -#' If the table has been cached/persisted before, it's also unpersisted. -#' -#' @param tableName The name of the SparkSQL table to be dropped. -#' @seealso \link{dropTempView} -#' @rdname dropTempTable-deprecated -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempTable("table") -#' } -#' @name dropTempTable -#' @method dropTempTable default -#' @note dropTempTable since 1.4.0 -dropTempTable.default <- function(tableName) { - if (class(tableName) != "character") { - stop("tableName must be a string.") - } - dropTempView(tableName) -} - -dropTempTable <- function(x, ...) { - .Deprecated("dropTempView") - dispatchFunc("dropTempView(viewName)", x, ...) -} - -#' Drops the temporary view with the given view name in the catalog. -#' -#' Drops the temporary view with the given view name in the catalog. -#' If the view has been cached before, then it will also be uncached. -#' -#' @param viewName the name of the view to be dropped. 
-#' @rdname dropTempView -#' @name dropTempView -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' df <- read.df(path, "parquet") -#' createOrReplaceTempView(df, "table") -#' dropTempView("table") -#' } -#' @note since 2.0.0 - -dropTempView <- function(viewName) { - sparkSession <- getSparkSession() - if (class(viewName) != "character") { - stop("viewName must be a string.") - } - catalog <- callJMethod(sparkSession, "catalog") - callJMethod(catalog, "dropTempView", viewName) -} - #' Load a SparkDataFrame #' #' Returns the dataset in a data source as a SparkDataFrame @@ -769,6 +590,7 @@ dropTempView <- function(viewName) { #' @return SparkDataFrame #' @rdname read.df #' @name read.df +#' @seealso \link{read.json} #' @export #' @examples #'\dontrun{ @@ -776,7 +598,7 @@ dropTempView <- function(viewName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df @@ -784,7 +606,7 @@ dropTempView <- function(viewName) { #' @note read.df since 1.4.0 read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) { if (!is.null(path) && !is.character(path)) { - stop("path should be charactor, NULL or omitted.") + stop("path should be character, NULL or omitted.") } if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the datasource specified ", @@ -828,45 +650,6 @@ loadDF <- function(x = NULL, ...) { dispatchFunc("loadDF(path = NULL, source = NULL, schema = NULL, ...)", x, ...) } -#' Create an external table -#' -#' Creates an external table based on the dataset in a data source, -#' Returns a SparkDataFrame associated with the external table. -#' -#' The data source is specified by the \code{source} and a set of options(...). -#' If \code{source} is not specified, the default data source configured by -#' "spark.sql.sources.default" will be used. -#' -#' @param tableName a name of the table. -#' @param path the path of files to load. -#' @param source the name of external data source. -#' @param ... additional argument(s) passed to the method. -#' @return A SparkDataFrame. -#' @rdname createExternalTable -#' @export -#' @examples -#'\dontrun{ -#' sparkR.session() -#' df <- createExternalTable("myjson", path="path/to/json", source="json") -#' } -#' @name createExternalTable -#' @method createExternalTable default -#' @note createExternalTable since 1.4.0 -createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) { - sparkSession <- getSparkSession() - options <- varargsToStrEnv(...) - if (!is.null(path)) { - options[["path"]] <- path - } - catalog <- callJMethod(sparkSession, "catalog") - sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options) - dataFrame(sdf) -} - -createExternalTable <- function(x, ...) { - dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) -} - #' Create a SparkDataFrame representing the database table accessible via JDBC URL #' #' Additional JDBC database connection properties can be set (...) 
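The updated `tableToDF` documentation above spells out how names are resolved (temporary view first, then the current database); a small sketch of that flow, assuming an active SparkR session and an illustrative view name:

```R
df <- createDataFrame(faithful)
createOrReplaceTempView(df, "faithful_view")

# tableToDF first looks for a temporary view, then for a table in the current database
df2 <- tableToDF("faithful_view")
head(df2)
```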
@@ -924,3 +707,53 @@ read.jdbc <- function(url, tableName, } dataFrame(sdf) } + +#' Load a streaming SparkDataFrame +#' +#' Returns the dataset in a data source as a SparkDataFrame +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param source The name of external data source +#' @param schema The data schema defined in structType, this is required for file-based streaming +#' data source +#' @param ... additional external data source specific named options, for instance \code{path} for +#' file-based streaming data source +#' @return SparkDataFrame +#' @rdname read.stream +#' @name read.stream +#' @seealso \link{write.stream} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- read.stream("socket", host = "localhost", port = 9999) +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' } +#' @name read.stream +#' @note read.stream since 2.2.0 +#' @note experimental +read.stream <- function(source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + if (!is.null(source) && !is.character(source)) { + stop("source should be character, NULL or omitted. It is the data source specified ", + "in 'spark.sql.sources.default' configuration by default.") + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + options <- varargsToStrEnv(...) + read <- callJMethod(sparkSession, "readStream") + read <- callJMethod(read, "format", source) + if (!is.null(schema)) { + stopifnot(class(schema) == "structType") + read <- callJMethod(read, "schema", schema$jobj) + } + read <- callJMethod(read, "options", options) + sdf <- handledCallJMethod(read, "load") + dataFrame(callJMethod(sdf, "toDF")) +} diff --git a/R/pkg/R/catalog.R b/R/pkg/R/catalog.R new file mode 100644 index 0000000000000..e59a7024333ac --- /dev/null +++ b/R/pkg/R/catalog.R @@ -0,0 +1,526 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# catalog.R: SparkSession catalog functions + +#' (Deprecated) Create an external table +#' +#' Creates an external table based on the dataset in a data source, +#' Returns a SparkDataFrame associated with the external table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. +#' +#' @param tableName a name of the table. +#' @param path the path of files to load. +#' @param source the name of external data source. 
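To tie the new `read.stream` entry point above to the `write.stream` method shown earlier, a minimal end-to-end sketch (assuming a socket source is listening on localhost:9999 and that the StreamingQuery helpers such as `stopQuery` are available in this version):

```R
lines <- read.stream("socket", host = "localhost", port = 9999)
counts <- count(groupBy(lines, "value"))          # running count per distinct input line

q <- write.stream(counts, "console", outputMode = "complete")
# ... let the query run for a while, then stop it
stopQuery(q)
```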
+#' @param schema the schema of the data required for some data sources. +#' @param ... additional argument(s) passed to the method. +#' @return A SparkDataFrame. +#' @rdname createExternalTable-deprecated +#' @seealso \link{createTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createExternalTable("myjson", path="path/to/json", source="json", schema) +#' } +#' @name createExternalTable +#' @method createExternalTable default +#' @note createExternalTable since 1.4.0 +createExternalTable.default <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + .Deprecated("createTable", old = "createExternalTable") + createTable(tableName, path, source, schema, ...) +} + +createExternalTable <- function(x, ...) { + dispatchFunc("createExternalTable(tableName, path = NULL, source = NULL, ...)", x, ...) +} + +#' Creates a table based on the dataset in a data source +#' +#' Creates a table based on the dataset in a data source. Returns a SparkDataFrame associated with +#' the table. +#' +#' The data source is specified by the \code{source} and a set of options(...). +#' If \code{source} is not specified, the default data source configured by +#' "spark.sql.sources.default" will be used. When a \code{path} is specified, an external table is +#' created from the data at the given path. Otherwise a managed table is created. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @param path (optional) the path of files to load. +#' @param source (optional) the name of the data source. +#' @param schema (optional) the schema of the data required for some data sources. +#' @param ... additional named parameters as options for the data source. +#' @return A SparkDataFrame. +#' @rdname createTable +#' @seealso \link{createExternalTable} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' df <- createTable("myjson", path="path/to/json", source="json", schema) +#' +#' createTable("people", source = "json", schema = schema) +#' insertInto(df, "people") +#' } +#' @name createTable +#' @note createTable since 2.2.0 +createTable <- function(tableName, path = NULL, source = NULL, schema = NULL, ...) { + sparkSession <- getSparkSession() + options <- varargsToStrEnv(...) + if (!is.null(path)) { + options[["path"]] <- path + } + if (is.null(source)) { + source <- getDefaultSqlSource() + } + catalog <- callJMethod(sparkSession, "catalog") + if (is.null(schema)) { + sdf <- callJMethod(catalog, "createTable", tableName, source, options) + } else if (class(schema) == "structType") { + sdf <- callJMethod(catalog, "createTable", tableName, source, schema$jobj, options) + } else { + stop("schema must be a structType.") + } + dataFrame(sdf) +} + +#' Cache Table +#' +#' Caches the specified table in-memory. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. 
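A short sketch contrasting the two `createTable` modes described above; the paths and table names are illustrative, and `schema` is assumed to be a `structType`:

```R
# With a path: an external table pointing at existing JSON files
people_ext <- createTable("people_json", path = "path/to/json", source = "json")

# Without a path: a managed table created from the given schema
schema <- structType(structField("name", "string"), structField("age", "integer"))
createTable("people", source = "json", schema = schema)
```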
+#' @return SparkDataFrame +#' @rdname cacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' cacheTable("table") +#' } +#' @name cacheTable +#' @method cacheTable default +#' @note cacheTable since 1.4.0 +cacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "cacheTable", tableName)) +} + +cacheTable <- function(x, ...) { + dispatchFunc("cacheTable(tableName)", x, ...) +} + +#' Uncache Table +#' +#' Removes the specified table from the in-memory cache. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @return SparkDataFrame +#' @rdname uncacheTable +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' path <- "path/to/file.json" +#' df <- read.json(path) +#' createOrReplaceTempView(df, "table") +#' uncacheTable("table") +#' } +#' @name uncacheTable +#' @method uncacheTable default +#' @note uncacheTable since 1.4.0 +uncacheTable.default <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "uncacheTable", tableName)) +} + +uncacheTable <- function(x, ...) { + dispatchFunc("uncacheTable(tableName)", x, ...) +} + +#' Clear Cache +#' +#' Removes all cached tables from the in-memory cache. +#' +#' @rdname clearCache +#' @export +#' @examples +#' \dontrun{ +#' clearCache() +#' } +#' @name clearCache +#' @method clearCache default +#' @note clearCache since 1.4.0 +clearCache.default <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(callJMethod(catalog, "clearCache")) +} + +clearCache <- function() { + dispatchFunc("clearCache()") +} + +#' (Deprecated) Drop Temporary Table +#' +#' Drops the temporary table with the given table name in the catalog. +#' If the table has been cached/persisted before, it's also unpersisted. +#' +#' @param tableName The name of the SparkSQL table to be dropped. +#' @seealso \link{dropTempView} +#' @rdname dropTempTable-deprecated +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempTable("table") +#' } +#' @name dropTempTable +#' @method dropTempTable default +#' @note dropTempTable since 1.4.0 +dropTempTable.default <- function(tableName) { + .Deprecated("dropTempView", old = "dropTempTable") + if (class(tableName) != "character") { + stop("tableName must be a string.") + } + dropTempView(tableName) +} + +dropTempTable <- function(x, ...) { + dispatchFunc("dropTempView(viewName)", x, ...) +} + +#' Drops the temporary view with the given view name in the catalog. +#' +#' Drops the temporary view with the given view name in the catalog. +#' If the view has been cached before, then it will also be uncached. +#' +#' @param viewName the name of the temporary view to be dropped. +#' @return TRUE if the view is dropped successfully, FALSE otherwise. 
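Putting the catalog cache helpers above together, a hedged sketch of the usual lifecycle of a cached temporary view (the input path is illustrative):

```R
df <- read.json("path/to/file.json")
createOrReplaceTempView(df, "people")

cacheTable("people")     # pin the view's data in memory
uncacheTable("people")   # drop it from the cache again
dropTempView("people")   # returns TRUE if the view existed and was dropped
```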
+#' @rdname dropTempView +#' @name dropTempView +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' df <- read.df(path, "parquet") +#' createOrReplaceTempView(df, "table") +#' dropTempView("table") +#' } +#' @note since 2.0.0 +dropTempView <- function(viewName) { + sparkSession <- getSparkSession() + if (class(viewName) != "character") { + stop("viewName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "dropTempView", viewName) +} + +#' Tables +#' +#' Returns a SparkDataFrame containing names of tables in the given database. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame +#' @rdname tables +#' @seealso \link{listTables} +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tables("hive") +#' } +#' @name tables +#' @method tables default +#' @note tables since 1.4.0 +tables.default <- function(databaseName = NULL) { + # rename column to match previous output schema + withColumnRenamed(listTables(databaseName), "name", "tableName") +} + +tables <- function(x, ...) { + dispatchFunc("tables(databaseName = NULL)", x, ...) +} + +#' Table Names +#' +#' Returns the names of tables in the given database as an array. +#' +#' @param databaseName (optional) name of the database +#' @return a list of table names +#' @rdname tableNames +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' tableNames("hive") +#' } +#' @name tableNames +#' @method tableNames default +#' @note tableNames since 1.4.0 +tableNames.default <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.sql.api.r.SQLUtils", + "getTableNames", + sparkSession, + databaseName) +} + +tableNames <- function(x, ...) { + dispatchFunc("tableNames(databaseName = NULL)", x, ...) +} + +#' Returns the current default database +#' +#' Returns the current default database. +#' +#' @return name of the current default database. +#' @rdname currentDatabase +#' @name currentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' currentDatabase() +#' } +#' @note since 2.2.0 +currentDatabase <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + callJMethod(catalog, "currentDatabase") +} + +#' Sets the current default database +#' +#' Sets the current default database. +#' +#' @param databaseName name of the database +#' @rdname setCurrentDatabase +#' @name setCurrentDatabase +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' setCurrentDatabase("default") +#' } +#' @note since 2.2.0 +setCurrentDatabase <- function(databaseName) { + sparkSession <- getSparkSession() + if (class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "setCurrentDatabase", databaseName)) +} + +#' Returns a list of databases available +#' +#' Returns a list of databases available. +#' +#' @return a SparkDataFrame of the list of databases. +#' @rdname listDatabases +#' @name listDatabases +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listDatabases() +#' } +#' @note since 2.2.0 +listDatabases <- function() { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + dataFrame(callJMethod(callJMethod(catalog, "listDatabases"), "toDF")) +} + +#' Returns a list of tables or views in the specified database +#' +#' Returns a list of tables or views in the specified database. 
+#' This includes all temporary views. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of tables. +#' @rdname listTables +#' @name listTables +#' @seealso \link{tables} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listTables() +#' listTables("default") +#' } +#' @note since 2.2.0 +listTables <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listTables") + } else { + handledCallJMethod(catalog, "listTables", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of columns for the given table/view in the specified database +#' +#' Returns a list of columns for the given table/view in the specified database. +#' +#' @param tableName the qualified or unqualified name that designates a table/view. If no database +#' identifier is provided, it refers to a table/view in the current database. +#' If \code{databaseName} parameter is specified, this must be an unqualified name. +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of column descriptions. +#' @rdname listColumns +#' @name listColumns +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listColumns("mytable") +#' } +#' @note since 2.2.0 +listColumns <- function(tableName, databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + handledCallJMethod(catalog, "listColumns", tableName) + } else { + handledCallJMethod(catalog, "listColumns", databaseName, tableName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Returns a list of functions registered in the specified database +#' +#' Returns a list of functions registered in the specified database. +#' This includes all temporary functions. +#' +#' @param databaseName (optional) name of the database +#' @return a SparkDataFrame of the list of function descriptions. +#' @rdname listFunctions +#' @name listFunctions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' listFunctions() +#' } +#' @note since 2.2.0 +listFunctions <- function(databaseName = NULL) { + sparkSession <- getSparkSession() + if (!is.null(databaseName) && class(databaseName) != "character") { + stop("databaseName must be a string.") + } + catalog <- callJMethod(sparkSession, "catalog") + jdst <- if (is.null(databaseName)) { + callJMethod(catalog, "listFunctions") + } else { + handledCallJMethod(catalog, "listFunctions", databaseName) + } + dataFrame(callJMethod(jdst, "toDF")) +} + +#' Recovers all the partitions in the directory of a table and update the catalog +#' +#' Recovers all the partitions in the directory of a table and update the catalog. The name should +#' reference a partitioned table, and not a view. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. 
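As a combined sketch of the catalog inspection helpers introduced above (output depends on the metastore; the table name is illustrative):

```R
currentDatabase()                  # e.g. "default"
collect(listDatabases())           # one row per database
collect(listTables("default"))     # includes temporary views
collect(listColumns("mytable"))    # assuming a table or view named "mytable" exists
collect(listFunctions())           # includes temporary functions
```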
+#' @rdname recoverPartitions +#' @name recoverPartitions +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' recoverPartitions("myTable") +#' } +#' @note since 2.2.0 +recoverPartitions <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "recoverPartitions", tableName)) +} + +#' Invalidates and refreshes all the cached data and metadata of the given table +#' +#' Invalidates and refreshes all the cached data and metadata of the given table. For performance +#' reasons, Spark SQL or the external data source library it uses might cache certain metadata about +#' a table, such as the location of blocks. When those change outside of Spark SQL, users should +#' call this function to invalidate the cache. +#' +#' If this table is cached as an InMemoryRelation, drop the original cached version and make the +#' new version cached lazily. +#' +#' @param tableName the qualified or unqualified name that designates a table. If no database +#' identifier is provided, it refers to a table in the current database. +#' @rdname refreshTable +#' @name refreshTable +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshTable("myTable") +#' } +#' @note since 2.2.0 +refreshTable <- function(tableName) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshTable", tableName)) +} + +#' Invalidates and refreshes all the cached data and metadata for SparkDataFrame containing path +#' +#' Invalidates and refreshes all the cached data (and the associated metadata) for any +#' SparkDataFrame that contains the given data source path. Path matching is by prefix, i.e. "/" +#' would invalidate everything that is cached. +#' +#' @param path the path of the data source. +#' @rdname refreshByPath +#' @name refreshByPath +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' refreshByPath("/path") +#' } +#' @note since 2.2.0 +refreshByPath <- function(path) { + sparkSession <- getSparkSession() + catalog <- callJMethod(sparkSession, "catalog") + invisible(handledCallJMethod(catalog, "refreshByPath", path)) +} diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 539d91b0f8797..574078012adad 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -67,8 +67,7 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or", #, "!" 
= "unary_$bang" - "^" = "pow" + "&" = "and", "|" = "or", "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") @@ -131,19 +130,19 @@ createMethods <- function() { createMethods() -#' alias -#' -#' Set a new name for a column -#' -#' @param object Column to rename -#' @param data new name to use -#' #' @rdname alias #' @name alias #' @aliases alias,Column-method #' @family colum_func #' @export -#' @note alias since 1.4.0 +#' @examples \dontrun{ +#' df <- createDataFrame(iris) +#' +#' head(select( +#' df, alias(df$Sepal_Length, "slength"), alias(df$Petal_Length, "plength") +#' )) +#' } +#' @note alias(Column) since 1.4.0 setMethod("alias", signature(object = "Column"), function(object, data) { @@ -302,3 +301,55 @@ setMethod("otherwise", jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) + +#' \%<=>\% +#' +#' Equality test that is safe for null values. +#' +#' Can be used, unlike standard equality operator, to perform null-safe joins. +#' Equivalent to Scala \code{Column.<=>} and \code{Column.eqNullSafe}. +#' +#' @param x a Column +#' @param value a value to compare +#' @rdname eq_null_safe +#' @name %<=>% +#' @aliases %<=>%,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df1 <- createDataFrame(data.frame( +#' x = c(1, NA, 3, NA), y = c(2, 6, 3, NA) +#' )) +#' +#' head(select(df1, df1$x == df1$y, df1$x %<=>% df1$y)) +#' +#' df2 <- createDataFrame(data.frame(y = c(3, NA))) +#' count(join(df1, df2, df1$y == df2$y)) +#' +#' count(join(df1, df2, df1$y %<=>% df2$y)) +#' } +#' @note \%<=>\% since 2.3.0 +setMethod("%<=>%", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- if (class(value) == "Column") { value@jc } else { value } + jc <- callJMethod(x@jc, "eqNullSafe", value) + column(jc) + }) + +#' ! +#' +#' Inversion of boolean expression. +#' +#' @rdname not +#' @name not +#' @aliases !,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) +#' +#' head(select(df, !column("x") > 0)) +#' } +#' @note ! since 2.3.0 +setMethod("!", signature(x = "Column"), function(x) not(x)) diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 438d77a388f0e..8349b57a30a93 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -87,10 +87,20 @@ objectFile <- function(sc, path, minPartitions = NULL) { #' in the list are split into \code{numSlices} slices and distributed to nodes #' in the cluster. #' -#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function -#' will write it to disk and send the file name to JVM. Also to make sure each slice is not +#' If size of serialized slices is larger than spark.r.maxAllocationLimit or (200MB), the function +#' will write it to disk and send the file name to JVM. Also to make sure each slice is not #' larger than that limit, number of slices may be increased. #' +#' In 2.2.0 we are changing how the numSlices are used/computed to handle +#' 1 < (length(coll) / numSlices) << length(coll) better, and to get the exact number of slices. +#' This change affects both createDataFrame and spark.lapply. +#' In the specific one case that it is used to convert R native object into SparkDataFrame, it has +#' always been kept at the default of 1. In the case the object is large, we are explicitly setting +#' the parallism to numSlices (which is still 1). 
+#' +#' Specifically, we are changing to split positions to match the calculation in positions() of +#' ParallelCollectionRDD in Spark. +#' #' @param sc SparkContext to use #' @param coll collection to parallelize #' @param numSlices number of partitions to create in the RDD @@ -107,6 +117,8 @@ parallelize <- function(sc, coll, numSlices = 1) { # TODO: bound/safeguard numSlices # TODO: unit tests for if the split works for all primitives # TODO: support matrix, data frame, etc + + # Note, for data.frame, createDataFrame turns it into a list before it calls here. # nolint start # suppress lintr warning: Place a space before left parenthesis, except in a function call. if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) { @@ -128,12 +140,29 @@ parallelize <- function(sc, coll, numSlices = 1) { objectSize <- object.size(coll) # For large objects we make sure the size of each slice is also smaller than sizeLimit - numSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) - if (numSlices > length(coll)) - numSlices <- length(coll) + numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit)) + if (numSerializedSlices > length(coll)) + numSerializedSlices <- length(coll) + + # Generate the slice ids to put each row + # For instance, for numSerializedSlices of 22, length of 50 + # [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22 + # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47 + # Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced. + # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD + splits <- if (numSerializedSlices > 0) { + unlist(lapply(0: (numSerializedSlices - 1), function(x) { + # nolint start + start <- trunc((x * length(coll)) / numSerializedSlices) + end <- trunc(((x + 1) * length(coll)) / numSerializedSlices) + # nolint end + rep(start, end - start) + })) + } else { + 1 + } - sliceLen <- ceiling(length(coll) / numSlices) - slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)]) + slices <- split(coll, splits) # Serialize each slice: obtain a list of raws, or a list of lists (slices) of # 2-tuples of raws @@ -229,7 +258,7 @@ includePackage <- function(sc, pkg) { #' #' # Large Matrix object that we want to broadcast #' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000)) -#' randomMatBr <- broadcast(sc, randomMat) +#' randomMatBr <- broadcastRDD(sc, randomMat) #' #' # Use the broadcast variable inside the function #' useBroadcast <- function(x) { @@ -237,7 +266,7 @@ includePackage <- function(sc, pkg) { #' } #' sumRDD <- lapply(rdd, useBroadcast) #'} -broadcast <- function(sc, object) { +broadcastRDD <- function(sc, object) { objName <- as.character(substitute(object)) serializedObj <- serialize(object, connection = NULL) @@ -262,7 +291,7 @@ broadcast <- function(sc, object) { #' rdd <- parallelize(sc, 1:2, 2L) #' checkpoint(rdd) #'} -setCheckpointDir <- function(sc, dirName) { +setCheckpointDirSC <- function(sc, dirName) { invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName)))) } @@ -301,7 +330,13 @@ spark.addFile <- function(path, recursive = FALSE) { #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 spark.getSparkFilesRootDirectory <- function() { - callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. 
+ callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") + } else { + # Running on worker. + Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR") + } } #' Get the absolute path of a file added through spark.addFile. @@ -316,7 +351,13 @@ spark.getSparkFilesRootDirectory <- function() { #'} #' @note spark.getSparkFiles since 2.1.0 spark.getSparkFiles <- function(fileName) { - callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { + # Running on driver. + callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)) + } else { + # Running on worker. + file.path(spark.getSparkFilesRootDirectory(), as.character(fileName)) + } } #' Run a function over a list of elements, distributing the computations with Spark @@ -379,5 +420,24 @@ spark.lapply <- function(list, func) { #' @note setLogLevel since 2.0.0 setLogLevel <- function(level) { sc <- getSparkContext() - callJMethod(sc, "setLogLevel", level) + invisible(callJMethod(sc, "setLogLevel", level)) +} + +#' Set checkpoint directory +#' +#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be +#' a HDFS path if running on a cluster. +#' +#' @rdname setCheckpointDir +#' @param directory Directory path to checkpoint to +#' @seealso \link{checkpoint} +#' @export +#' @examples +#'\dontrun{ +#' setCheckpointDir("/checkpoint") +#'} +#' @note setCheckpointDir since 2.2.0 +setCheckpointDir <- function(directory) { + sc <- getSparkContext() + invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(directory)))) } diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 4d94b4cd05d44..a6c2dea0ff2a7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -24,7 +24,7 @@ NULL #' If the parameter is a \linkS4class{Column}, it is returned unchanged. #' #' @param x a literal value or a Column. 
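The new `setCheckpointDir` above pairs with the `checkpoint` method on SparkDataFrame shown earlier; a minimal sketch (the directory is illustrative and must be an HDFS path when running on a cluster):

```R
setCheckpointDir("/tmp/spark-checkpoints")
df <- createDataFrame(mtcars)
df <- checkpoint(df)   # eager by default: materializes the data under the checkpoint directory
```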
-#' @family normal_funcs +#' @family non-aggregate functions #' @rdname lit #' @name lit #' @export @@ -52,7 +52,7 @@ setMethod("lit", signature("ANY"), #' #' @rdname abs #' @name abs -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @examples \dontrun{abs(df$c)} #' @aliases abs,Column-method @@ -73,7 +73,7 @@ setMethod("abs", #' #' @rdname acos #' @name acos -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{acos(df$c)} #' @aliases acos,Column-method @@ -113,7 +113,7 @@ setMethod("approxCountDistinct", #' #' @rdname ascii #' @name ascii -#' @family string_funcs +#' @family string functions #' @export #' @aliases ascii,Column-method #' @examples \dontrun{\dontrun{ascii(df$c)}} @@ -134,7 +134,7 @@ setMethod("ascii", #' #' @rdname asin #' @name asin -#' @family math_funcs +#' @family math functions #' @export #' @aliases asin,Column-method #' @examples \dontrun{asin(df$c)} @@ -154,7 +154,7 @@ setMethod("asin", #' #' @rdname atan #' @name atan -#' @family math_funcs +#' @family math functions #' @export #' @aliases atan,Column-method #' @examples \dontrun{atan(df$c)} @@ -172,7 +172,7 @@ setMethod("atan", #' #' @rdname avg #' @name avg -#' @family agg_funcs +#' @family aggregate functions #' @export #' @aliases avg,Column-method #' @examples \dontrun{avg(df$c)} @@ -193,7 +193,7 @@ setMethod("avg", #' #' @rdname base64 #' @name base64 -#' @family string_funcs +#' @family string functions #' @export #' @aliases base64,Column-method #' @examples \dontrun{base64(df$c)} @@ -214,7 +214,7 @@ setMethod("base64", #' #' @rdname bin #' @name bin -#' @family math_funcs +#' @family math functions #' @export #' @aliases bin,Column-method #' @examples \dontrun{bin(df$c)} @@ -234,7 +234,7 @@ setMethod("bin", #' #' @rdname bitwiseNOT #' @name bitwiseNOT -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases bitwiseNOT,Column-method #' @examples \dontrun{bitwiseNOT(df$c)} @@ -254,7 +254,7 @@ setMethod("bitwiseNOT", #' #' @rdname cbrt #' @name cbrt -#' @family math_funcs +#' @family math functions #' @export #' @aliases cbrt,Column-method #' @examples \dontrun{cbrt(df$c)} @@ -274,7 +274,7 @@ setMethod("cbrt", #' #' @rdname ceil #' @name ceil -#' @family math_funcs +#' @family math functions #' @export #' @aliases ceil,Column-method #' @examples \dontrun{ceil(df$c)} @@ -286,6 +286,28 @@ setMethod("ceil", column(jc) }) +#' Returns the first column that is not NA +#' +#' Returns the first column that is not NA, or NA if all inputs are. +#' +#' @rdname coalesce +#' @name coalesce +#' @family non-aggregate functions +#' @export +#' @aliases coalesce,Column-method +#' @examples \dontrun{coalesce(df$c, df$d, df$e)} +#' @note coalesce(Column) since 2.1.1 +setMethod("coalesce", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "coalesce", jcols) + column(jc) + }) + #' Though scala functions has "col" function, we don't expose it in SparkR #' because we don't want to conflict with the "col" function in the R base #' package and we also have "column" function exported which is an alias of "col". @@ -297,15 +319,15 @@ col <- function(x) { #' Returns a Column based on the given column name #' #' Returns a Column based on the given column name. -# +#' #' @param x Character column name. 
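A small sketch of the column-level `coalesce` added above, which picks the first non-NA value per row (the data frame is illustrative):

```R
df <- createDataFrame(data.frame(a = c(1, NA, NA), b = c(NA, 2, NA)))

# first non-null of a, b, and a literal fallback: 1, 2, 0
head(select(df, coalesce(df$a, df$b, lit(0))))
```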
#' #' @rdname column #' @name column -#' @family normal_funcs +#' @family non-aggregate functions #' @export #' @aliases column,character-method -#' @examples \dontrun{column(df)} +#' @examples \dontrun{column("name")} #' @note column since 1.6.0 setMethod("column", signature(x = "character"), @@ -320,7 +342,7 @@ setMethod("column", #' #' @rdname corr #' @name corr -#' @family math_funcs +#' @family math functions #' @export #' @aliases corr,Column-method #' @examples \dontrun{corr(df$c, df$d)} @@ -338,7 +360,7 @@ setMethod("corr", signature(x = "Column"), #' #' @rdname cov #' @name cov -#' @family math_funcs +#' @family math functions #' @export #' @aliases cov,characterOrColumn-method #' @examples @@ -382,7 +404,7 @@ setMethod("covar_samp", signature(col1 = "characterOrColumn", col2 = "characterO #' #' @rdname covar_pop #' @name covar_pop -#' @family math_funcs +#' @family math functions #' @export #' @aliases covar_pop,characterOrColumn,characterOrColumn-method #' @examples @@ -410,7 +432,7 @@ setMethod("covar_pop", signature(col1 = "characterOrColumn", col2 = "characterOr #' #' @rdname cos #' @name cos -#' @family math_funcs +#' @family math functions #' @aliases cos,Column-method #' @export #' @examples \dontrun{cos(df$c)} @@ -430,7 +452,7 @@ setMethod("cos", #' #' @rdname cosh #' @name cosh -#' @family math_funcs +#' @family math functions #' @aliases cosh,Column-method #' @export #' @examples \dontrun{cosh(df$c)} @@ -449,7 +471,7 @@ setMethod("cosh", #' #' @rdname count #' @name count -#' @family agg_funcs +#' @family aggregate functions #' @aliases count,Column-method #' @export #' @examples \dontrun{count(df$c)} @@ -470,7 +492,7 @@ setMethod("count", #' #' @rdname crc32 #' @name crc32 -#' @family misc_funcs +#' @family misc functions #' @aliases crc32,Column-method #' @export #' @examples \dontrun{crc32(df$c)} @@ -491,7 +513,7 @@ setMethod("crc32", #' #' @rdname hash #' @name hash -#' @family misc_funcs +#' @family misc functions #' @aliases hash,Column-method #' @export #' @examples \dontrun{hash(df$c)} @@ -515,7 +537,7 @@ setMethod("hash", #' #' @rdname dayofmonth #' @name dayofmonth -#' @family datetime_funcs +#' @family date time functions #' @aliases dayofmonth,Column-method #' @export #' @examples \dontrun{dayofmonth(df$c)} @@ -535,7 +557,7 @@ setMethod("dayofmonth", #' #' @rdname dayofyear #' @name dayofyear -#' @family datetime_funcs +#' @family date time functions #' @aliases dayofyear,Column-method #' @export #' @examples \dontrun{dayofyear(df$c)} @@ -557,7 +579,7 @@ setMethod("dayofyear", #' #' @rdname decode #' @name decode -#' @family string_funcs +#' @family string functions #' @aliases decode,Column,character-method #' @export #' @examples \dontrun{decode(df$c, "UTF-8")} @@ -579,7 +601,7 @@ setMethod("decode", #' #' @rdname encode #' @name encode -#' @family string_funcs +#' @family string functions #' @aliases encode,Column,character-method #' @export #' @examples \dontrun{encode(df$c, "UTF-8")} @@ -599,7 +621,7 @@ setMethod("encode", #' #' @rdname exp #' @name exp -#' @family math_funcs +#' @family math functions #' @aliases exp,Column-method #' @export #' @examples \dontrun{exp(df$c)} @@ -620,7 +642,7 @@ setMethod("exp", #' @rdname expm1 #' @name expm1 #' @aliases expm1,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{expm1(df$c)} #' @note expm1 since 1.5.0 @@ -640,7 +662,7 @@ setMethod("expm1", #' @rdname factorial #' @name factorial #' @aliases factorial,Column-method -#' @family math_funcs +#' @family math 
functions #' @export #' @examples \dontrun{factorial(df$c)} #' @note factorial since 1.5.0 @@ -664,7 +686,7 @@ setMethod("factorial", #' @rdname first #' @name first #' @aliases first,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -693,7 +715,7 @@ setMethod("first", #' @rdname floor #' @name floor #' @aliases floor,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{floor(df$c)} #' @note floor since 1.5.0 @@ -712,7 +734,7 @@ setMethod("floor", #' #' @rdname hex #' @name hex -#' @family math_funcs +#' @family math functions #' @aliases hex,Column-method #' @export #' @examples \dontrun{hex(df$c)} @@ -733,7 +755,7 @@ setMethod("hex", #' @rdname hour #' @name hour #' @aliases hour,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{hour(df$c)} #' @note hour since 1.5.0 @@ -755,7 +777,7 @@ setMethod("hour", #' #' @rdname initcap #' @name initcap -#' @family string_funcs +#' @family string functions #' @aliases initcap,Column-method #' @export #' @examples \dontrun{initcap(df$c)} @@ -775,7 +797,7 @@ setMethod("initcap", #' #' @rdname is.nan #' @name is.nan -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases is.nan,Column-method #' @export #' @examples @@ -810,7 +832,7 @@ setMethod("isnan", #' @rdname kurtosis #' @name kurtosis #' @aliases kurtosis,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples \dontrun{kurtosis(df$c)} #' @note kurtosis since 1.6.0 @@ -836,7 +858,7 @@ setMethod("kurtosis", #' @rdname last #' @name last #' @aliases last,characterOrColumn-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples #' \dontrun{ @@ -867,7 +889,7 @@ setMethod("last", #' @rdname last_day #' @name last_day #' @aliases last_day,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{last_day(df$c)} #' @note last_day since 1.5.0 @@ -887,7 +909,7 @@ setMethod("last_day", #' @rdname length #' @name length #' @aliases length,Column-method -#' @family string_funcs +#' @family string functions #' @export #' @examples \dontrun{length(df$c)} #' @note length since 1.5.0 @@ -907,7 +929,7 @@ setMethod("length", #' @rdname log #' @name log #' @aliases log,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{log(df$c)} #' @note log since 1.5.0 @@ -926,7 +948,7 @@ setMethod("log", #' #' @rdname log10 #' @name log10 -#' @family math_funcs +#' @family math functions #' @aliases log10,Column-method #' @export #' @examples \dontrun{log10(df$c)} @@ -946,7 +968,7 @@ setMethod("log10", #' #' @rdname log1p #' @name log1p -#' @family math_funcs +#' @family math functions #' @aliases log1p,Column-method #' @export #' @examples \dontrun{log1p(df$c)} @@ -966,7 +988,7 @@ setMethod("log1p", #' #' @rdname log2 #' @name log2 -#' @family math_funcs +#' @family math functions #' @aliases log2,Column-method #' @export #' @examples \dontrun{log2(df$c)} @@ -986,7 +1008,7 @@ setMethod("log2", #' #' @rdname lower #' @name lower -#' @family string_funcs +#' @family string functions #' @aliases lower,Column-method #' @export #' @examples \dontrun{lower(df$c)} @@ -1006,7 +1028,7 @@ setMethod("lower", #' #' @rdname ltrim #' @name ltrim -#' @family string_funcs +#' @family string functions #' @aliases ltrim,Column-method #' @export #' @examples \dontrun{ltrim(df$c)} @@ -1026,7 +1048,7 @@ 
setMethod("ltrim", #' #' @rdname max #' @name max -#' @family agg_funcs +#' @family aggregate functions #' @aliases max,Column-method #' @export #' @examples \dontrun{max(df$c)} @@ -1047,7 +1069,7 @@ setMethod("max", #' #' @rdname md5 #' @name md5 -#' @family misc_funcs +#' @family misc functions #' @aliases md5,Column-method #' @export #' @examples \dontrun{md5(df$c)} @@ -1068,7 +1090,7 @@ setMethod("md5", #' #' @rdname mean #' @name mean -#' @family agg_funcs +#' @family aggregate functions #' @aliases mean,Column-method #' @export #' @examples \dontrun{mean(df$c)} @@ -1089,7 +1111,7 @@ setMethod("mean", #' @rdname min #' @name min #' @aliases min,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @export #' @examples \dontrun{min(df$c)} #' @note min since 1.5.0 @@ -1109,7 +1131,7 @@ setMethod("min", #' @rdname minute #' @name minute #' @aliases minute,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{minute(df$c)} #' @note minute since 1.5.0 @@ -1138,7 +1160,7 @@ setMethod("minute", #' @rdname monotonically_increasing_id #' @aliases monotonically_increasing_id,missing-method #' @name monotonically_increasing_id -#' @family misc_funcs +#' @family misc functions #' @export #' @examples \dontrun{select(df, monotonically_increasing_id())} setMethod("monotonically_increasing_id", @@ -1157,7 +1179,7 @@ setMethod("monotonically_increasing_id", #' @rdname month #' @name month #' @aliases month,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{month(df$c)} #' @note month since 1.5.0 @@ -1176,7 +1198,7 @@ setMethod("month", #' #' @rdname negate #' @name negate -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases negate,Column-method #' @export #' @examples \dontrun{negate(df$c)} @@ -1196,7 +1218,7 @@ setMethod("negate", #' #' @rdname quarter #' @name quarter -#' @family datetime_funcs +#' @family date time functions #' @aliases quarter,Column-method #' @export #' @examples \dontrun{quarter(df$c)} @@ -1216,7 +1238,7 @@ setMethod("quarter", #' #' @rdname reverse #' @name reverse -#' @family string_funcs +#' @family string functions #' @aliases reverse,Column-method #' @export #' @examples \dontrun{reverse(df$c)} @@ -1237,7 +1259,7 @@ setMethod("reverse", #' #' @rdname rint #' @name rint -#' @family math_funcs +#' @family math functions #' @aliases rint,Column-method #' @export #' @examples \dontrun{rint(df$c)} @@ -1257,7 +1279,7 @@ setMethod("rint", #' #' @rdname round #' @name round -#' @family math_funcs +#' @family math functions #' @aliases round,Column-method #' @export #' @examples \dontrun{round(df$c)} @@ -1283,7 +1305,7 @@ setMethod("round", #' @param ... further arguments to be passed to or from other methods. #' @rdname bround #' @name bround -#' @family math_funcs +#' @family math functions #' @aliases bround,Column-method #' @export #' @examples \dontrun{bround(df$c, 0)} @@ -1304,7 +1326,7 @@ setMethod("bround", #' #' @rdname rtrim #' @name rtrim -#' @family string_funcs +#' @family string functions #' @aliases rtrim,Column-method #' @export #' @examples \dontrun{rtrim(df$c)} @@ -1324,7 +1346,7 @@ setMethod("rtrim", #' @param na.rm currently not used. 
#' @rdname sd #' @name sd -#' @family agg_funcs +#' @family aggregate functions #' @aliases sd,Column-method #' @seealso \link{stddev_pop}, \link{stddev_samp} #' @export @@ -1350,7 +1372,7 @@ setMethod("sd", #' #' @rdname second #' @name second -#' @family datetime_funcs +#' @family date time functions #' @aliases second,Column-method #' @export #' @examples \dontrun{second(df$c)} @@ -1371,7 +1393,7 @@ setMethod("second", #' #' @rdname sha1 #' @name sha1 -#' @family misc_funcs +#' @family misc functions #' @aliases sha1,Column-method #' @export #' @examples \dontrun{sha1(df$c)} @@ -1392,7 +1414,7 @@ setMethod("sha1", #' @rdname sign #' @name signum #' @aliases signum,Column-method -#' @family math_funcs +#' @family math functions #' @export #' @examples \dontrun{signum(df$c)} #' @note signum since 1.5.0 @@ -1411,7 +1433,7 @@ setMethod("signum", #' #' @rdname sin #' @name sin -#' @family math_funcs +#' @family math functions #' @aliases sin,Column-method #' @export #' @examples \dontrun{sin(df$c)} @@ -1431,7 +1453,7 @@ setMethod("sin", #' #' @rdname sinh #' @name sinh -#' @family math_funcs +#' @family math functions #' @aliases sinh,Column-method #' @export #' @examples \dontrun{sinh(df$c)} @@ -1451,7 +1473,7 @@ setMethod("sinh", #' #' @rdname skewness #' @name skewness -#' @family agg_funcs +#' @family aggregate functions #' @aliases skewness,Column-method #' @export #' @examples \dontrun{skewness(df$c)} @@ -1471,7 +1493,7 @@ setMethod("skewness", #' #' @rdname soundex #' @name soundex -#' @family string_funcs +#' @family string functions #' @aliases soundex,Column-method #' @export #' @examples \dontrun{soundex(df$c)} @@ -1485,7 +1507,7 @@ setMethod("soundex", #' Return the partition ID as a column #' -#' Return the partition ID of the Spark task as a SparkDataFrame column. +#' Return the partition ID as a SparkDataFrame column. #' Note that this is nondeterministic because it depends on data partitioning and #' task scheduling. 
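The reworded description above stresses that the partition ID column is nondeterministic. A short sketch of what that looks like in practice, assuming the SparkR column functions `spark_partition_id` and `monotonically_increasing_id` (both present in this API) and an active session:

```R
library(SparkR)
df <- createDataFrame(mtcars)
# both values depend on how rows are partitioned and on task scheduling,
# so they can change between runs of the same query
head(select(df, monotonically_increasing_id(), spark_partition_id()))
```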
#' @@ -1524,7 +1546,7 @@ setMethod("stddev", #' #' @rdname stddev_pop #' @name stddev_pop -#' @family agg_funcs +#' @family aggregate functions #' @aliases stddev_pop,Column-method #' @seealso \link{sd}, \link{stddev_samp} #' @export @@ -1545,7 +1567,7 @@ setMethod("stddev_pop", #' #' @rdname stddev_samp #' @name stddev_samp -#' @family agg_funcs +#' @family aggregate functions #' @aliases stddev_samp,Column-method #' @seealso \link{stddev_pop}, \link{sd} #' @export @@ -1567,7 +1589,7 @@ setMethod("stddev_samp", #' #' @rdname struct #' @name struct -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases struct,characterOrColumn-method #' @export #' @examples @@ -1596,7 +1618,7 @@ setMethod("struct", #' #' @rdname sqrt #' @name sqrt -#' @family math_funcs +#' @family math functions #' @aliases sqrt,Column-method #' @export #' @examples \dontrun{sqrt(df$c)} @@ -1616,7 +1638,7 @@ setMethod("sqrt", #' #' @rdname sum #' @name sum -#' @family agg_funcs +#' @family aggregate functions #' @aliases sum,Column-method #' @export #' @examples \dontrun{sum(df$c)} @@ -1636,7 +1658,7 @@ setMethod("sum", #' #' @rdname sumDistinct #' @name sumDistinct -#' @family agg_funcs +#' @family aggregate functions #' @aliases sumDistinct,Column-method #' @export #' @examples \dontrun{sumDistinct(df$c)} @@ -1656,7 +1678,7 @@ setMethod("sumDistinct", #' #' @rdname tan #' @name tan -#' @family math_funcs +#' @family math functions #' @aliases tan,Column-method #' @export #' @examples \dontrun{tan(df$c)} @@ -1676,7 +1698,7 @@ setMethod("tan", #' #' @rdname tanh #' @name tanh -#' @family math_funcs +#' @family math functions #' @aliases tanh,Column-method #' @export #' @examples \dontrun{tanh(df$c)} @@ -1696,7 +1718,7 @@ setMethod("tanh", #' #' @rdname toDegrees #' @name toDegrees -#' @family math_funcs +#' @family math functions #' @aliases toDegrees,Column-method #' @export #' @examples \dontrun{toDegrees(df$c)} @@ -1716,7 +1738,7 @@ setMethod("toDegrees", #' #' @rdname toRadians #' @name toRadians -#' @family math_funcs +#' @family math functions #' @aliases toRadians,Column-method #' @export #' @examples \dontrun{toRadians(df$c)} @@ -1730,24 +1752,124 @@ setMethod("toRadians", #' to_date #' -#' Converts the column into DateType. +#' Converts the column into a DateType. You may optionally specify a format +#' according to the rules in: +#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' If the string cannot be parsed according to the specified format (or default), +#' the value of the column will be null. +#' By default, it follows casting rules to a DateType if the format is omitted +#' (equivalent to \code{cast(df$x, "date")}). #' -#' @param x Column to compute on. +#' @param x Column to parse. +#' @param format string to use to parse x Column to DateType. 
(optional) #' #' @rdname to_date #' @name to_date -#' @family datetime_funcs -#' @aliases to_date,Column-method +#' @family date time functions +#' @aliases to_date,Column,missing-method #' @export -#' @examples \dontrun{to_date(df$c)} -#' @note to_date since 1.5.0 +#' @examples +#' \dontrun{ +#' to_date(df$c) +#' to_date(df$c, 'yyyy-MM-dd') +#' } +#' @note to_date(Column) since 1.5.0 setMethod("to_date", - signature(x = "Column"), - function(x) { + signature(x = "Column", format = "missing"), + function(x, format) { jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc) column(jc) }) +#' @rdname to_date +#' @name to_date +#' @family date time functions +#' @aliases to_date,Column,character-method +#' @export +#' @note to_date(Column, character) since 2.2.0 +setMethod("to_date", + signature(x = "Column", format = "character"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_date", x@jc, format) + column(jc) + }) + +#' to_json +#' +#' Converts a column containing a \code{structType} or array of \code{structType} into a Column +#' of JSON string. Resolving the Column can fail if an unsupported type is encountered. +#' +#' @param x Column containing the struct or array of the structs +#' @param ... additional named properties to control how it is converted, accepts the same options +#' as the JSON data source. +#' +#' @family non-aggregate functions +#' @rdname to_json +#' @name to_json +#' @aliases to_json,Column-method +#' @export +#' @examples +#' \dontrun{ +#' # Converts a struct into a JSON object +#' df <- sql("SELECT named_struct('date', cast('2000-01-01' as date)) as d") +#' select(df, to_json(df$d, dateFormat = 'dd/MM/yyyy')) +#' +#' # Converts an array of structs into a JSON array +#' df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") +#' select(df, to_json(df$people)) +#'} +#' @note to_json since 2.2.0 +setMethod("to_json", signature(x = "Column"), + function(x, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", "to_json", x@jc, options) + column(jc) + }) + +#' to_timestamp +#' +#' Converts the column into a TimestampType. You may optionally specify a format +#' according to the rules in: +#' \url{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}. +#' If the string cannot be parsed according to the specified format (or default), +#' the value of the column will be null. +#' By default, it follows casting rules to a TimestampType if the format is omitted +#' (equivalent to \code{cast(df$x, "timestamp")}). +#' +#' @param x Column to parse. +#' @param format string to use to parse x Column to TimestampType. 
(optional) +#' +#' @rdname to_timestamp +#' @name to_timestamp +#' @family date time functions +#' @aliases to_timestamp,Column,missing-method +#' @export +#' @examples +#' \dontrun{ +#' to_timestamp(df$c) +#' to_timestamp(df$c, 'yyyy-MM-dd') +#' } +#' @note to_timestamp(Column) since 2.2.0 +setMethod("to_timestamp", + signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_timestamp", x@jc) + column(jc) + }) + +#' @rdname to_timestamp +#' @name to_timestamp +#' @family date time functions +#' @aliases to_timestamp,Column,character-method +#' @export +#' @note to_timestamp(Column, character) since 2.2.0 +setMethod("to_timestamp", + signature(x = "Column", format = "character"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_timestamp", x@jc, format) + column(jc) + }) + #' trim #' #' Trim the spaces from both ends for the specified string column. @@ -1756,7 +1878,7 @@ setMethod("to_date", #' #' @rdname trim #' @name trim -#' @family string_funcs +#' @family string functions #' @aliases trim,Column-method #' @export #' @examples \dontrun{trim(df$c)} @@ -1777,7 +1899,7 @@ setMethod("trim", #' #' @rdname unbase64 #' @name unbase64 -#' @family string_funcs +#' @family string functions #' @aliases unbase64,Column-method #' @export #' @examples \dontrun{unbase64(df$c)} @@ -1798,7 +1920,7 @@ setMethod("unbase64", #' #' @rdname unhex #' @name unhex -#' @family math_funcs +#' @family math functions #' @aliases unhex,Column-method #' @export #' @examples \dontrun{unhex(df$c)} @@ -1818,7 +1940,7 @@ setMethod("unhex", #' #' @rdname upper #' @name upper -#' @family string_funcs +#' @family string functions #' @aliases upper,Column-method #' @export #' @examples \dontrun{upper(df$c)} @@ -1838,7 +1960,7 @@ setMethod("upper", #' @param y,na.rm,use currently not used. 
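The new optional `format` argument for `to_date` and `to_timestamp` (added for 2.2.0 above) lets a non-default pattern be parsed explicitly; without it, the default cast rules apply and unparseable strings become null. A minimal sketch, assuming SparkR built from this change:

```R
library(SparkR)
df <- createDataFrame(data.frame(d = "31/01/2017", stringsAsFactors = FALSE))
# the default yyyy-MM-dd style parse would yield NA for this string,
# so an explicit SimpleDateFormat pattern is supplied
head(select(df, to_date(df$d, "dd/MM/yyyy"), to_timestamp(df$d, "dd/MM/yyyy")))
```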
#' @rdname var #' @name var -#' @family agg_funcs +#' @family aggregate functions #' @aliases var,Column-method #' @seealso \link{var_pop}, \link{var_samp} #' @export @@ -1875,7 +1997,7 @@ setMethod("variance", #' #' @rdname var_pop #' @name var_pop -#' @family agg_funcs +#' @family aggregate functions #' @aliases var_pop,Column-method #' @seealso \link{var}, \link{var_samp} #' @export @@ -1897,7 +2019,7 @@ setMethod("var_pop", #' @rdname var_samp #' @name var_samp #' @aliases var_samp,Column-method -#' @family agg_funcs +#' @family aggregate functions #' @seealso \link{var_pop}, \link{var} #' @export #' @examples \dontrun{var_samp(df$c)} @@ -1918,7 +2040,7 @@ setMethod("var_samp", #' @rdname weekofyear #' @name weekofyear #' @aliases weekofyear,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{weekofyear(df$c)} #' @note weekofyear since 1.5.0 @@ -1937,7 +2059,7 @@ setMethod("weekofyear", #' #' @rdname year #' @name year -#' @family datetime_funcs +#' @family date time functions #' @aliases year,Column-method #' @export #' @examples \dontrun{year(df$c)} @@ -1959,7 +2081,7 @@ setMethod("year", #' #' @rdname atan2 #' @name atan2 -#' @family math_funcs +#' @family math functions #' @aliases atan2,Column-method #' @export #' @examples \dontrun{atan2(df$c, x)} @@ -1983,7 +2105,7 @@ setMethod("atan2", signature(y = "Column"), #' @rdname datediff #' @name datediff #' @aliases datediff,Column-method -#' @family datetime_funcs +#' @family date time functions #' @export #' @examples \dontrun{datediff(df$c, x)} #' @note datediff since 1.5.0 @@ -2005,7 +2127,7 @@ setMethod("datediff", signature(y = "Column"), #' #' @rdname hypot #' @name hypot -#' @family math_funcs +#' @family math functions #' @aliases hypot,Column-method #' @export #' @examples \dontrun{hypot(df$c, x)} @@ -2028,7 +2150,7 @@ setMethod("hypot", signature(y = "Column"), #' #' @rdname levenshtein #' @name levenshtein -#' @family string_funcs +#' @family string functions #' @aliases levenshtein,Column-method #' @export #' @examples \dontrun{levenshtein(df$c, x)} @@ -2051,7 +2173,7 @@ setMethod("levenshtein", signature(y = "Column"), #' #' @rdname months_between #' @name months_between -#' @family datetime_funcs +#' @family date time functions #' @aliases months_between,Column-method #' @export #' @examples \dontrun{months_between(df$c, x)} @@ -2075,7 +2197,7 @@ setMethod("months_between", signature(y = "Column"), #' #' @rdname nanvl #' @name nanvl -#' @family normal_funcs +#' @family non-aggregate functions #' @aliases nanvl,Column-method #' @export #' @examples \dontrun{nanvl(df$c, x)} @@ -2099,7 +2221,7 @@ setMethod("nanvl", signature(y = "Column"), #' @rdname pmod #' @name pmod #' @docType methods -#' @family math_funcs +#' @family math functions #' @aliases pmod,Column-method #' @export #' @examples \dontrun{pmod(df$c, x)} @@ -2137,7 +2259,7 @@ setMethod("approxCountDistinct", #' @param x Column to compute on #' @param ... other columns #' -#' @family agg_funcs +#' @family aggregate functions #' @rdname countDistinct #' @name countDistinct #' @aliases countDistinct,Column-method @@ -2165,7 +2287,7 @@ setMethod("countDistinct", #' @param x Column to compute on #' @param ... other columns #' -#' @family string_funcs +#' @family string functions #' @rdname concat #' @name concat #' @aliases concat,Column-method @@ -2191,7 +2313,7 @@ setMethod("concat", #' @param x Column to compute on #' @param ... 
other columns #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname greatest #' @name greatest #' @aliases greatest,Column-method @@ -2218,7 +2340,7 @@ setMethod("greatest", #' @param x Column to compute on #' @param ... other columns #' -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname least #' @aliases least,Column-method #' @name least @@ -2296,13 +2418,13 @@ setMethod("n", signature(x = "Column"), #' A pattern could be for instance \preformatted{dd.MM.yyyy} and could return a string like '18.03.1993'. All #' pattern letters of \code{java.text.SimpleDateFormat} can be used. #' -#' NOTE: Use when ever possible specialized functions like \code{year}. These benefit from a +#' Note: Use when ever possible specialized functions like \code{year}. These benefit from a #' specialized implementation. #' #' @param y Column to compute on. #' @param x date format specification. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_format #' @name date_format #' @aliases date_format,Column,character-method @@ -2315,14 +2437,54 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) +#' from_json +#' +#' Parses a column containing a JSON string into a Column of \code{structType} with the specified +#' \code{schema} or array of \code{structType} if \code{as.json.array} is set to \code{TRUE}. +#' If the string is unparseable, the Column will contains the value NA. +#' +#' @param x Column containing the JSON string. +#' @param schema a structType object to use as the schema to use when parsing the JSON string. +#' @param as.json.array indicating if input string is JSON array of objects or a single object. +#' @param ... additional named properties to control how the json is parsed, accepts the same +#' options as the JSON data source. +#' +#' @family non-aggregate functions +#' @rdname from_json +#' @name from_json +#' @aliases from_json,Column,structType-method +#' @export +#' @examples +#' \dontrun{ +#' schema <- structType(structField("name", "string"), +#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) +#'} +#' @note from_json since 2.2.0 +setMethod("from_json", signature(x = "Column", schema = "structType"), + function(x, schema, as.json.array = FALSE, ...) { + if (as.json.array) { + jschema <- callJStatic("org.apache.spark.sql.types.DataTypes", + "createArrayType", + schema$jobj) + } else { + jschema <- schema$jobj + } + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", + "from_json", + x@jc, jschema, options) + column(jc) + }) + #' from_utc_timestamp #' -#' Assumes given timestamp is UTC and converts to given timezone. +#' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp +#' that corresponds to the same time of day in the given timezone. #' #' @param y Column to compute on. #' @param x time zone to use. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname from_utc_timestamp #' @name from_utc_timestamp #' @aliases from_utc_timestamp,Column,character-method @@ -2340,12 +2502,12 @@ setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), #' Locate the position of the first occurrence of substr column in the given string. #' Returns null if either of the arguments are null. #' -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. 
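The `from_json` method added above is the inverse of the `to_json` method introduced earlier in this patch. A minimal round-trip sketch, assuming a SparkR build that includes both (2.2.0 per the notes) and an active session:

```R
library(SparkR)
# build a struct column, serialize it to JSON, then parse it back with a schema
df <- sql("SELECT named_struct('name', 'Bob', 'age', 30) AS person")
jsonDF <- select(df, alias(to_json(df$person), "json"))
schema <- structType(structField("name", "string"), structField("age", "integer"))
head(select(jsonDF, from_json(jsonDF$json, schema)))
```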
#' #' @param y column to check #' @param x substring to check -#' @family string_funcs +#' @family string functions #' @aliases instr,Column,character-method #' @rdname instr #' @name instr @@ -2372,7 +2534,7 @@ setMethod("instr", signature(y = "Column", x = "character"), #' @param y Column to compute on. #' @param x Day of the week string. #' -#' @family datetime_funcs +#' @family date time functions #' @rdname next_day #' @name next_day #' @aliases next_day,Column,character-method @@ -2391,12 +2553,13 @@ setMethod("next_day", signature(y = "Column", x = "character"), #' to_utc_timestamp #' -#' Assumes given timestamp is in given timezone and converts to UTC. +#' Given a timestamp, which corresponds to a certain time of day in the given timezone, returns +#' another timestamp that corresponds to the same time of day in UTC. #' #' @param y Column to compute on #' @param x timezone to use #' -#' @family datetime_funcs +#' @family date time functions #' @rdname to_utc_timestamp #' @name to_utc_timestamp #' @aliases to_utc_timestamp,Column,character-method @@ -2417,7 +2580,7 @@ setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), #' @param x Number of months to add #' #' @name add_months -#' @family datetime_funcs +#' @family date time functions #' @rdname add_months #' @aliases add_months,Column,numeric-method #' @export @@ -2436,7 +2599,7 @@ setMethod("add_months", signature(y = "Column", x = "numeric"), #' @param y Column to compute on #' @param x Number of days to add #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_add #' @name date_add #' @aliases date_add,Column,numeric-method @@ -2456,7 +2619,7 @@ setMethod("date_add", signature(y = "Column", x = "numeric"), #' @param y Column to compute on #' @param x Number of days to substract #' -#' @family datetime_funcs +#' @family date time functions #' @rdname date_sub #' @name date_sub #' @aliases date_sub,Column,numeric-method @@ -2471,15 +2634,15 @@ setMethod("date_sub", signature(y = "Column", x = "numeric"), #' format_number #' -#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, -#' and returns the result as a string column. +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places +#' with HALF_EVEN round mode, and returns the result as a string column. #' #' If x is 0, the result has no decimal point or fractional part. #' If x < 0, the result will be null. #' #' @param y column to format #' @param x number of decimal place to format to -#' @family string_funcs +#' @family string functions #' @rdname format_number #' @name format_number #' @aliases format_number,Column,numeric-method @@ -2501,7 +2664,7 @@ setMethod("format_number", signature(y = "Column", x = "numeric"), #' #' @param y column to compute SHA-2 on. #' @param x one of 224, 256, 384, or 512. -#' @family misc_funcs +#' @family misc functions #' @rdname sha2 #' @name sha2 #' @aliases sha2,Column,numeric-method @@ -2522,7 +2685,7 @@ setMethod("sha2", signature(y = "Column", x = "numeric"), #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftLeft #' @name shiftLeft #' @aliases shiftLeft,Column,numeric-method @@ -2539,13 +2702,13 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), #' shiftRight #' -#' Shift the given value numBits right. If the given value is a long value, it will return +#' (Signed) shift the given value numBits right. 
If the given value is a long value, it will return #' a long value else it will return an integer value. #' #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftRight #' @name shiftRight #' @aliases shiftRight,Column,numeric-method @@ -2568,7 +2731,7 @@ setMethod("shiftRight", signature(y = "Column", x = "numeric"), #' @param y column to compute on. #' @param x number of bits to shift. #' -#' @family math_funcs +#' @family math functions #' @rdname shiftRightUnsigned #' @name shiftRightUnsigned #' @aliases shiftRightUnsigned,Column,numeric-method @@ -2592,7 +2755,7 @@ setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), #' @param sep separator to use. #' @param ... other columns to concatenate. #' -#' @family string_funcs +#' @family string functions #' @rdname concat_ws #' @name concat_ws #' @aliases concat_ws,character,Column-method @@ -2614,7 +2777,7 @@ setMethod("concat_ws", signature(sep = "character", x = "Column"), #' @param fromBase base to convert from. #' @param toBase base to convert to. #' -#' @family math_funcs +#' @family math functions #' @rdname conv #' @aliases conv,Column,numeric,numeric-method #' @name conv @@ -2637,7 +2800,7 @@ setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeri #' SparkDataFrame.selectExpr #' #' @param x an expression character object to be parsed. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname expr #' @aliases expr,character-method #' @name expr @@ -2657,7 +2820,7 @@ setMethod("expr", signature(x = "character"), #' @param format a character object of format strings. #' @param x a Column. #' @param ... additional Column(s). -#' @family string_funcs +#' @family string functions #' @rdname format_string #' @name format_string #' @aliases format_string,character,Column-method @@ -2684,7 +2847,7 @@ setMethod("format_string", signature(format = "character", x = "Column"), #' \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' @param ... further arguments to be passed to or from other methods. -#' @family datetime_funcs +#' @family date time functions #' @rdname from_unixtime #' @name from_unixtime #' @aliases from_unixtime,Column-method @@ -2729,7 +2892,7 @@ setMethod("from_unixtime", signature(x = "Column"), #' @param ... further arguments to be passed to or from other methods. #' @return An output column of struct called 'window' by default with the nested columns 'start' #' and 'end'. -#' @family datetime_funcs +#' @family date time functions #' @rdname window #' @name window #' @aliases window,Column-method @@ -2777,14 +2940,15 @@ setMethod("window", signature(x = "Column"), #' locate #' #' Locate the position of the first occurrence of substr. -#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' +#' Note: The position is not zero based, but 1 based index. Returns 0 if substr #' could not be found in str. #' #' @param substr a character string to be matched. #' @param str a Column where matches are sought for each entry. #' @param pos start position of search. #' @param ... further arguments to be passed to or from other methods. 
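The doc fix above clarifies that `shiftRight` is a signed shift, in contrast to `shiftRightUnsigned`. A small sketch of the difference (the cast is an assumption made for the example, since the shift functions are defined on integral columns):

```R
library(SparkR)
df <- createDataFrame(data.frame(x = c(-8, 8)))
xi <- cast(df$x, "integer")
# the signed shift preserves the sign bit; the unsigned shift does not,
# so -8 becomes -4 in the first column but a large positive value in the second
head(select(df, shiftRight(xi, 1), shiftRightUnsigned(xi, 1)))
```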
-#' @family string_funcs +#' @family string functions #' @rdname locate #' @aliases locate,character,Column-method #' @name locate @@ -2806,7 +2970,7 @@ setMethod("locate", signature(substr = "character", str = "Column"), #' @param x the string Column to be left-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string_funcs +#' @family string functions #' @rdname lpad #' @aliases lpad,Column,numeric,character-method #' @name lpad @@ -2823,10 +2987,11 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), #' rand #' -#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' Generate a random column with independent and identically distributed (i.i.d.) samples +#' from U[0.0, 1.0]. #' #' @param seed a random seed. Can be missing. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname rand #' @name rand #' @aliases rand,missing-method @@ -2852,10 +3017,11 @@ setMethod("rand", signature(seed = "numeric"), #' randn #' -#' Generate a column with i.i.d. samples from the standard normal distribution. +#' Generate a column with independent and identically distributed (i.i.d.) samples from +#' the standard normal distribution. #' #' @param seed a random seed. Can be missing. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname randn #' @name randn #' @aliases randn,missing-method @@ -2887,7 +3053,7 @@ setMethod("randn", signature(seed = "numeric"), #' @param x a string Column. #' @param pattern a regular expression. #' @param idx a group index. -#' @family string_funcs +#' @family string functions #' @rdname regexp_extract #' @name regexp_extract #' @aliases regexp_extract,Column,character,numeric-method @@ -2910,7 +3076,7 @@ setMethod("regexp_extract", #' @param x a string Column. #' @param pattern a regular expression. #' @param replacement a character string that a matched \code{pattern} is replaced with. -#' @family string_funcs +#' @family string functions #' @rdname regexp_replace #' @name regexp_replace #' @aliases regexp_replace,Column,character,character-method @@ -2933,7 +3099,7 @@ setMethod("regexp_replace", #' @param x the string Column to be right-padded. #' @param len maximum length of each output result. #' @param pad a character string to be padded with. -#' @family string_funcs +#' @family string functions #' @rdname rpad #' @name rpad #' @aliases rpad,Column,numeric,character-method @@ -2960,7 +3126,7 @@ setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), #' @param count number of occurrences of \code{delim} before the substring is returned. #' A positive number means counting from the left, while negative means #' counting from the right. -#' @family string_funcs +#' @family string functions #' @rdname substring_index #' @aliases substring_index,Column,character,numeric-method #' @name substring_index @@ -2992,7 +3158,7 @@ setMethod("substring_index", #' @param replaceString a target string where each \code{matchingString} character will #' be replaced by the character in \code{replaceString} #' at the same location, if any. -#' @family string_funcs +#' @family string functions #' @rdname translate #' @name translate #' @aliases translate,Column,character,character-method @@ -3011,7 +3177,7 @@ setMethod("translate", #' #' Gets current Unix timestamp in seconds. 
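The `locate` note above (and the matching `instr` note earlier) emphasizes 1-based positions with 0 for "not found". A short illustration, together with `regexp_extract` from the same group of string functions, assuming an active session:

```R
library(SparkR)
df <- createDataFrame(data.frame(s = "spark-2.2.0", stringsAsFactors = FALSE))
# locate() returns a 1-based position (7 here) and 0 when the substring is absent
head(select(df, locate("2", df$s), locate("z", df$s),
            regexp_extract(df$s, "(\\d+)\\.(\\d+)", 1)))
```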
#' -#' @family datetime_funcs +#' @family date time functions #' @rdname unix_timestamp #' @name unix_timestamp #' @aliases unix_timestamp,missing,missing-method @@ -3061,7 +3227,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), #' #' @param condition the condition to test on. Must be a Column expression. #' @param value result expression. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname when #' @name when #' @aliases when,Column-method @@ -3085,7 +3251,7 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @param test a Column expression that describes the condition. #' @param yes return values for \code{TRUE} elements of test. #' @param no return values for \code{FALSE} elements of test. -#' @family normal_funcs +#' @family non-aggregate functions #' @rdname ifelse #' @name ifelse #' @aliases ifelse,Column-method @@ -3123,7 +3289,7 @@ setMethod("ifelse", #' #' @rdname cume_dist #' @name cume_dist -#' @family window_funcs +#' @family window functions #' @aliases cume_dist,missing-method #' @export #' @examples \dontrun{ @@ -3145,13 +3311,14 @@ setMethod("cume_dist", #' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking #' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second -#' place and that the next person came in third. +#' place and that the next person came in third. Rank would give me sequential numbers, making +#' the person that came in third place (after the ties) would register as coming in fifth. #' #' This is equivalent to the \code{DENSE_RANK} function in SQL. #' #' @rdname dense_rank #' @name dense_rank -#' @family window_funcs +#' @family window functions #' @aliases dense_rank,missing-method #' @export #' @examples \dontrun{ @@ -3183,7 +3350,7 @@ setMethod("dense_rank", #' @rdname lag #' @name lag #' @aliases lag,characterOrColumn-method -#' @family window_funcs +#' @family window functions #' @export #' @examples \dontrun{ #' df <- createDataFrame(mtcars) @@ -3225,7 +3392,7 @@ setMethod("lag", #' #' @rdname lead #' @name lead -#' @family window_funcs +#' @family window functions #' @aliases lead,characterOrColumn,numeric-method #' @export #' @examples \dontrun{ @@ -3265,7 +3432,7 @@ setMethod("lead", #' @rdname ntile #' @name ntile #' @aliases ntile,numeric-method -#' @family window_funcs +#' @family window functions #' @export #' @examples \dontrun{ #' df <- createDataFrame(mtcars) @@ -3296,7 +3463,7 @@ setMethod("ntile", #' #' @rdname percent_rank #' @name percent_rank -#' @family window_funcs +#' @family window functions #' @aliases percent_rank,missing-method #' @export #' @examples \dontrun{ @@ -3316,16 +3483,17 @@ setMethod("percent_rank", #' #' Window function: returns the rank of rows within a window partition. #' -#' The difference between rank and denseRank is that denseRank leaves no gaps in ranking -#' sequence when there are ties. That is, if you were ranking a competition using denseRank +#' The difference between rank and dense_rank is that dense_rank leaves no gaps in ranking +#' sequence when there are ties. That is, if you were ranking a competition using dense_rank #' and had three people tie for second place, you would say that all three were in second -#' place and that the next person came in third. +#' place and that the next person came in third. 
Rank would give me sequential numbers, making +#' the person that came in third place (after the ties) would register as coming in fifth. #' #' This is equivalent to the RANK function in SQL. #' #' @rdname rank #' @name rank -#' @family window_funcs +#' @family window functions #' @aliases rank,missing-method #' @export #' @examples \dontrun{ @@ -3363,7 +3531,7 @@ setMethod("rank", #' @rdname row_number #' @name row_number #' @aliases row_number,missing-method -#' @family window_funcs +#' @family window functions #' @export #' @examples \dontrun{ #' df <- createDataFrame(mtcars) @@ -3382,14 +3550,14 @@ setMethod("row_number", #' array_contains #' -#' Returns true if the array contain the value. +#' Returns null if the array is null, true if the array contains the value, and false otherwise. #' #' @param x A Column #' @param value A value to be checked if contained in the column #' @rdname array_contains #' @aliases array_contains,Column-method #' @name array_contains -#' @family collection_funcs +#' @family collection functions #' @export #' @examples \dontrun{array_contains(df$c, 1)} #' @note array_contains since 1.6.0 @@ -3408,7 +3576,7 @@ setMethod("array_contains", #' #' @rdname explode #' @name explode -#' @family collection_funcs +#' @family collection functions #' @aliases explode,Column-method #' @export #' @examples \dontrun{explode(df$c)} @@ -3429,7 +3597,7 @@ setMethod("explode", #' @rdname size #' @name size #' @aliases size,Column-method -#' @family collection_funcs +#' @family collection functions #' @export #' @examples \dontrun{size(df$c)} #' @note size since 1.5.0 @@ -3442,8 +3610,8 @@ setMethod("size", #' sort_array #' -#' Sorts the input array for the given column in ascending order, -#' according to the natural ordering of the array elements. +#' Sorts the input array in ascending or descending order according +#' to the natural ordering of the array elements. #' #' @param x A Column to sort #' @param asc A logical flag indicating the sorting order. @@ -3452,7 +3620,7 @@ setMethod("size", #' @rdname sort_array #' @name sort_array #' @aliases sort_array,Column-method -#' @family collection_funcs +#' @family collection functions #' @export #' @examples #' \dontrun{ @@ -3475,7 +3643,7 @@ setMethod("sort_array", #' #' @rdname posexplode #' @name posexplode -#' @family collection_funcs +#' @family collection functions #' @aliases posexplode,Column-method #' @export #' @examples \dontrun{posexplode(df$c)} @@ -3486,3 +3654,347 @@ setMethod("posexplode", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc) column(jc) }) + +#' create_array +#' +#' Creates a new array column. The input columns must all have the same data type. +#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family non-aggregate functions +#' @rdname create_array +#' @name create_array +#' @aliases create_array,Column-method +#' @export +#' @examples \dontrun{create_array(df$x, df$y, df$z)} +#' @note create_array since 2.3.0 +setMethod("create_array", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "array", jcols) + column(jc) + }) + +#' create_map +#' +#' Creates a new map column. The input columns must be grouped as key-value pairs, +#' e.g. (key1, value1, key2, value2, ...). +#' The key columns must all have the same data type, and can't be null. +#' The value columns must all have the same data type. 
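The new `create_array` and `create_map` helpers above (2.3.0 per their notes) build nested columns from existing ones; map keys must be non-null and share a type. A minimal sketch, assuming a build with these functions:

```R
library(SparkR)
df <- createDataFrame(data.frame(x = 1, y = 2))
# literal string keys paired with numeric value columns
head(select(df, create_array(df$x, df$y),
            create_map(lit("x"), df$x, lit("y"), df$y)))
```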
+#' +#' @param x Column to compute on +#' @param ... additional Column(s). +#' +#' @family non-aggregate functions +#' @rdname create_map +#' @name create_map +#' @aliases create_map,Column-method +#' @export +#' @examples \dontrun{create_map(lit("x"), lit(1.0), lit("y"), lit(-1.0))} +#' @note create_map since 2.3.0 +setMethod("create_map", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols) + column(jc) + }) + +#' collect_list +#' +#' Creates a list of objects with duplicates. +#' +#' @param x Column to compute on +#' +#' @rdname collect_list +#' @name collect_list +#' @family aggregate functions +#' @aliases collect_list,Column-method +#' @export +#' @examples \dontrun{collect_list(df$x)} +#' @note collect_list since 2.3.0 +setMethod("collect_list", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc) + column(jc) + }) + +#' collect_set +#' +#' Creates a list of objects with duplicate elements eliminated. +#' +#' @param x Column to compute on +#' +#' @rdname collect_set +#' @name collect_set +#' @family aggregate functions +#' @aliases collect_set,Column-method +#' @export +#' @examples \dontrun{collect_set(df$x)} +#' @note collect_set since 2.3.0 +setMethod("collect_set", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc) + column(jc) + }) + +#' split_string +#' +#' Splits string on regular expression. +#' +#' Equivalent to \code{split} SQL function +#' +#' @param x Column to compute on +#' @param pattern Java regular expression +#' +#' @rdname split_string +#' @family string functions +#' @aliases split_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, split_string(df$value, "\\s+"))) +#' +#' # This is equivalent to the following SQL expression +#' head(selectExpr(df, "split(value, '\\\\s+')")) +#' } +#' @note split_string 2.3.0 +setMethod("split_string", + signature(x = "Column", pattern = "character"), + function(x, pattern) { + jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + column(jc) + }) + +#' repeat_string +#' +#' Repeats string n times. +#' +#' Equivalent to \code{repeat} SQL function +#' +#' @param x Column to compute on +#' @param n Number of repetitions +#' +#' @rdname repeat_string +#' @family string functions +#' @aliases repeat_string,Column-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' first(select(df, repeat_string(df$value, 3))) +#' +#' # This is equivalent to the following SQL expression +#' first(selectExpr(df, "repeat(value, 3)")) +#' } +#' @note repeat_string since 2.3.0 +setMethod("repeat_string", + signature(x = "Column", n = "numeric"), + function(x, n) { + jc <- callJStatic("org.apache.spark.sql.functions", "repeat", x@jc, numToInt(n)) + column(jc) + }) + +#' explode_outer +#' +#' Creates a new row for each element in the given array or map column. +#' Unlike \code{explode}, if the array/map is \code{null} or empty +#' then \code{null} is produced. 
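`collect_list` and `collect_set` above differ only in whether duplicates are kept. A small sketch of that distinction, assuming a SparkR build with these 2.3.0 additions:

```R
library(SparkR)
df <- createDataFrame(data.frame(g = c("a", "a", "a"), v = c(1, 1, 2),
                                 stringsAsFactors = FALSE))
# collect_list keeps the duplicate 1; collect_set returns only distinct values
head(agg(groupBy(df, "g"), collect_list(df$v), collect_set(df$v)))
```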
+#' +#' @param x Column to compute on +#' +#' @rdname explode_outer +#' @name explode_outer +#' @family collection functions +#' @aliases explode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, explode_outer(split_string(df$text, ",")))) +#' } +#' @note explode_outer since 2.3.0 +setMethod("explode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "explode_outer", x@jc) + column(jc) + }) + +#' posexplode_outer +#' +#' Creates a new row for each element with position in the given array or map column. +#' Unlike \code{posexplode}, if the array/map is \code{null} or empty +#' then the row (\code{null}, \code{null}) is produced. +#' +#' @param x Column to compute on +#' +#' @rdname posexplode_outer +#' @name posexplode_outer +#' @family collection functions +#' @aliases posexplode_outer,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' id = c(1, 2, 3), text = c("a,b,c", NA, "d,e") +#' )) +#' +#' head(select(df, df$id, posexplode_outer(split_string(df$text, ",")))) +#' } +#' @note posexplode_outer since 2.3.0 +setMethod("posexplode_outer", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) + column(jc) + }) + +#' not +#' +#' Inversion of boolean expression. +#' +#' \code{not} and \code{!} cannot be applied directly to numerical column. +#' To achieve R-like truthiness column has to be casted to \code{BooleanType}. +#' +#' @param x Column to compute on +#' @rdname not +#' @name not +#' @aliases not,Column-method +#' @family non-aggregate functions +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' is_true = c(TRUE, FALSE, NA), +#' flag = c(1, 0, 1) +#' )) +#' +#' head(select(df, not(df$is_true))) +#' +#' # Explicit cast is required when working with numeric column +#' head(select(df, not(cast(df$flag, "boolean")))) +#' } +#' @note not since 2.3.0 +setMethod("not", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) + column(jc) + }) + +#' grouping_bit +#' +#' Indicates whether a specified column in a GROUP BY list is aggregated or not, +#' returns 1 for aggregated or 0 for not aggregated in the result set. +#' +#' Same as \code{GROUPING} in SQL and \code{grouping} function in Scala. +#' +#' @param x Column to compute on +#' +#' @rdname grouping_bit +#' @name grouping_bit +#' @family aggregate functions +#' @aliases grouping_bit,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am) +#' ) +#' } +#' @note grouping_bit since 2.3.0 +setMethod("grouping_bit", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "grouping", x@jc) + column(jc) + }) + +#' grouping_id +#' +#' Returns the level of grouping. +#' +#' Equals to \code{ +#' grouping_bit(c1) * 2^(n - 1) + grouping_bit(c2) * 2^(n - 2) + ... + grouping_bit(cn) +#' } +#' +#' @param x Column to compute on +#' @param ... additional Column(s) (optional). 
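The `grouping_id` formula above can be checked by hand: with three grouping columns, a row where only the first is aggregated away gets 1 * 2^2 + 0 * 2^1 + 0 * 2^0 = 4. A sketch using the `cube` generic added elsewhere in this patch, assuming a 2.3.0-level build:

```R
library(SparkR)
df <- createDataFrame(mtcars)
# compare the per-column bits with the combined grouping_id for each cube row
head(agg(cube(df, "cyl", "gear", "am"),
         grouping_bit(df$cyl), grouping_bit(df$gear), grouping_bit(df$am),
         grouping_id(df$cyl, df$gear, df$am)))
```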
+#' +#' @rdname grouping_id +#' @name grouping_id +#' @family aggregate functions +#' @aliases grouping_id,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(mtcars) +#' +#' # With cube +#' agg( +#' cube(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' +#' # With rollup +#' agg( +#' rollup(df, "cyl", "gear", "am"), +#' mean(df$mpg), +#' grouping_id(df$cyl, df$gear, df$am) +#' ) +#' } +#' @note grouping_id since 2.3.0 +setMethod("grouping_id", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "grouping_id", jcols) + column(jc) + }) + +#' input_file_name +#' +#' Creates a string column with the input file name for a given row +#' +#' @rdname input_file_name +#' @name input_file_name +#' @family non-aggregate functions +#' @aliases input_file_name,missing-method +#' @export +#' @examples \dontrun{ +#' df <- read.text("README.md") +#' +#' head(select(df, input_file_name())) +#' } +#' @note input_file_name since 2.3.0 +setMethod("input_file_name", signature("missing"), + function() { + jc <- callJStatic("org.apache.spark.sql.functions", "input_file_name") + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0271b26a10a90..514ca99d45cd3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -28,11 +28,11 @@ setGeneric("cacheRDD", function(x) { standardGeneric("cacheRDD") }) # @rdname coalesce # @seealso repartition # @export -setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") }) +setGeneric("coalesceRDD", function(x, numPartitions, ...) { standardGeneric("coalesceRDD") }) # @rdname checkpoint-methods # @export -setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") }) +setGeneric("checkpointRDD", function(x) { standardGeneric("checkpointRDD") }) setGeneric("collectRDD", function(x, ...) { standardGeneric("collectRDD") }) @@ -66,7 +66,7 @@ setGeneric("freqItems", function(x, cols, support = 0.01) { standardGeneric("fre # @rdname approxQuantile # @export setGeneric("approxQuantile", - function(x, col, probabilities, relativeError) { + function(x, cols, probabilities, relativeError) { standardGeneric("approxQuantile") }) @@ -138,9 +138,9 @@ setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") }) # @export setGeneric("name", function(x) { standardGeneric("name") }) -# @rdname getNumPartitions +# @rdname getNumPartitionsRDD # @export -setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) +setGeneric("getNumPartitionsRDD", function(x) { standardGeneric("getNumPartitionsRDD") }) # @rdname getNumPartitions # @export @@ -387,6 +387,17 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @export setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +#' alias +#' +#' Returns a new SparkDataFrame or a Column with an alias set. Equivalent to SQL "AS" keyword. +#' +#' @name alias +#' @rdname alias +#' @param object x a SparkDataFrame or a Column +#' @param data new name to use +#' @return a SparkDataFrame or a Column +NULL + #' @rdname arrange #' @export setGeneric("arrange", function(x, col, ...) 
{ standardGeneric("arrange") }) @@ -406,6 +417,17 @@ setGeneric("attach") #' @export setGeneric("cache", function(x) { standardGeneric("cache") }) +#' @rdname checkpoint +#' @export +setGeneric("checkpoint", function(x, eager = TRUE) { standardGeneric("checkpoint") }) + +#' @rdname coalesce +#' @param x a Column or a SparkDataFrame. +#' @param ... additional argument(s). If \code{x} is a Column, additional Columns can be optionally +#' provided. +#' @export +setGeneric("coalesce", function(x, ...) { standardGeneric("coalesce") }) + #' @rdname collect #' @export setGeneric("collect", function(x, ...) { standardGeneric("collect") }) @@ -472,6 +494,10 @@ setGeneric("createOrReplaceTempView", # @export setGeneric("crossJoin", function(x, y) { standardGeneric("crossJoin") }) +#' @rdname cube +#' @export +setGeneric("cube", function(x, ...) { standardGeneric("cube") }) + #' @rdname dapply #' @export setGeneric("dapply", function(x, func, schema) { standardGeneric("dapply") }) @@ -492,6 +518,10 @@ setGeneric("gapply", function(x, ...) { standardGeneric("gapply") }) #' @export setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") }) +# @rdname getNumPartitions +# @export +setGeneric("getNumPartitions", function(x) { standardGeneric("getNumPartitions") }) + #' @rdname summary #' @export setGeneric("describe", function(x, col, ...) { standardGeneric("describe") }) @@ -528,6 +558,9 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") }) #' @rdname explain #' @export +#' @param x a SparkDataFrame or a StreamingQuery. +#' @param extended Logical. If extended is FALSE, prints only the physical plan. +#' @param ... further arguments to be passed to or from other methods. setGeneric("explain", function(x, ...) { standardGeneric("explain") }) #' @rdname except @@ -554,6 +587,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) @@ -566,6 +603,10 @@ setGeneric("intersect", function(x, y) { standardGeneric("intersect") }) #' @export setGeneric("isLocal", function(x) { standardGeneric("isLocal") }) +#' @rdname isStreaming +#' @export +setGeneric("isStreaming", function(x) { standardGeneric("isStreaming") }) + #' @rdname limit #' @export setGeneric("limit", function(x, num) {standardGeneric("limit") }) @@ -609,6 +650,10 @@ setGeneric("sample", standardGeneric("sample") }) +#' @rdname rollup +#' @export +setGeneric("rollup", function(x, ...) { standardGeneric("rollup") }) + #' @rdname sample #' @export setGeneric("sample_frac", @@ -671,6 +716,12 @@ setGeneric("write.parquet", function(x, path, ...) { #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.stream +#' @export +setGeneric("write.stream", function(df, source = NULL, outputMode = NULL, ...) { + standardGeneric("write.stream") +}) + #' @rdname write.text #' @export setGeneric("write.text", function(x, path, ...) { standardGeneric("write.text") }) @@ -748,6 +799,10 @@ setGeneric("write.df", function(df, path = NULL, ...) 
{ standardGeneric("write.d #' @export setGeneric("randomSplit", function(x, weights, seed) { standardGeneric("randomSplit") }) +#' @rdname broadcast +#' @export +setGeneric("broadcast", function(x) { standardGeneric("broadcast") }) + ###################### Column Methods ########################## #' @rdname columnfunctions @@ -820,6 +875,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @export setGeneric("over", function(x, window) { standardGeneric("over") }) +#' @rdname eq_null_safe +#' @export +setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) + ###################### WindowSpec Methods ########################## #' @rdname partitionBy @@ -890,6 +949,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) #' @export setGeneric("ceil", function(x) { standardGeneric("ceil") }) +#' @rdname collect_list +#' @export +setGeneric("collect_list", function(x) { standardGeneric("collect_list") }) + +#' @rdname collect_set +#' @export +setGeneric("collect_set", function(x) { standardGeneric("collect_set") }) + #' @rdname column #' @export setGeneric("column", function(x) { standardGeneric("column") }) @@ -914,6 +981,14 @@ setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname create_array +#' @export +setGeneric("create_array", function(x, ...) { standardGeneric("create_array") }) + +#' @rdname create_map +#' @export +setGeneric("create_map", function(x, ...) { standardGeneric("create_map") }) + #' @rdname hash #' @export setGeneric("hash", function(x, ...) { standardGeneric("hash") }) @@ -964,6 +1039,10 @@ setGeneric("encode", function(x, charset) { standardGeneric("encode") }) #' @export setGeneric("explode", function(x) { standardGeneric("explode") }) +#' @rdname explode_outer +#' @export +setGeneric("explode_outer", function(x) { standardGeneric("explode_outer") }) + #' @rdname expr #' @export setGeneric("expr", function(x) { standardGeneric("expr") }) @@ -980,6 +1059,10 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @export setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) +#' @rdname from_json +#' @export +setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) + #' @rdname from_unixtime #' @export setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) @@ -988,6 +1071,14 @@ setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") #' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) +#' @rdname grouping_bit +#' @export +setGeneric("grouping_bit", function(x) { standardGeneric("grouping_bit") }) + +#' @rdname grouping_id +#' @export +setGeneric("grouping_id", function(x, ...) { standardGeneric("grouping_id") }) + #' @rdname hex #' @export setGeneric("hex", function(x) { standardGeneric("hex") }) @@ -1004,6 +1095,12 @@ setGeneric("hypot", function(y, x) { standardGeneric("hypot") }) #' @export setGeneric("initcap", function(x) { standardGeneric("initcap") }) +#' @param x empty. Should be used with no argument. 
+#' @rdname input_file_name +#' @export +setGeneric("input_file_name", + function(x = "missing") { standardGeneric("input_file_name") }) + #' @rdname instr #' @export setGeneric("instr", function(y, x) { standardGeneric("instr") }) @@ -1094,6 +1191,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) +#' @rdname not +#' @export +setGeneric("not", function(x) { standardGeneric("not") }) + #' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) @@ -1119,6 +1220,10 @@ setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @export setGeneric("posexplode", function(x) { standardGeneric("posexplode") }) +#' @rdname posexplode_outer +#' @export +setGeneric("posexplode_outer", function(x) { standardGeneric("posexplode_outer") }) + #' @rdname quarter #' @export setGeneric("quarter", function(x) { standardGeneric("quarter") }) @@ -1144,6 +1249,10 @@ setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp setGeneric("regexp_replace", function(x, pattern, replacement) { standardGeneric("regexp_replace") }) +#' @rdname repeat_string +#' @export +setGeneric("repeat_string", function(x, n) { standardGeneric("repeat_string") }) + #' @rdname reverse #' @export setGeneric("reverse", function(x) { standardGeneric("reverse") }) @@ -1209,6 +1318,10 @@ setGeneric("skewness", function(x) { standardGeneric("skewness") }) #' @export setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") }) +#' @rdname split_string +#' @export +setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) + #' @rdname soundex #' @export setGeneric("soundex", function(x) { standardGeneric("soundex") }) @@ -1252,7 +1365,15 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname to_date #' @export -setGeneric("to_date", function(x) { standardGeneric("to_date") }) +setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) + +#' @rdname to_json +#' @export +setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) + +#' @rdname to_timestamp +#' @export +setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) #' @rdname to_utc_timestamp #' @export @@ -1310,6 +1431,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") }) #' @export setGeneric("year", function(x) { standardGeneric("year") }) + ###################### Spark.ML Methods ########################## #' @rdname fitted @@ -1338,11 +1460,20 @@ setGeneric("rbind", signature = "...") #' @export setGeneric("spark.als", function(data, ...) { standardGeneric("spark.als") }) +#' @rdname spark.bisectingKmeans +#' @export +setGeneric("spark.bisectingKmeans", + function(data, formula, ...) { standardGeneric("spark.bisectingKmeans") }) + #' @rdname spark.gaussianMixture #' @export setGeneric("spark.gaussianMixture", function(data, formula, ...) { standardGeneric("spark.gaussianMixture") }) +#' @rdname spark.gbt +#' @export +setGeneric("spark.gbt", function(data, formula, ...) { standardGeneric("spark.gbt") }) + #' @rdname spark.glm #' @export setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.glm") }) @@ -1369,7 +1500,7 @@ setGeneric("spark.logit", function(data, formula, ...) { standardGeneric("spark. #' @rdname spark.mlp #' @export -setGeneric("spark.mlp", function(data, ...) 
{ standardGeneric("spark.mlp") }) +setGeneric("spark.mlp", function(data, formula, ...) { standardGeneric("spark.mlp") }) #' @rdname spark.naiveBayes #' @export @@ -1382,7 +1513,11 @@ setGeneric("spark.randomForest", #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) + +#' @rdname spark.svmLinear +#' @export +setGeneric("spark.svmLinear", function(data, formula, ...) { standardGeneric("spark.svmLinear") }) #' @rdname spark.lda #' @export @@ -1392,6 +1527,17 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark #' @export setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") }) +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.fpGrowth", function(data, ...) { standardGeneric("spark.fpGrowth") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.freqItemsets", function(object) { standardGeneric("spark.freqItemsets") }) + +#' @rdname spark.fpGrowth +#' @export +setGeneric("spark.associationRules", function(object) { standardGeneric("spark.associationRules") }) #' @param object a fitted ML model object. #' @param path the directory where the model is saved. @@ -1399,3 +1545,30 @@ setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.p #' @rdname write.ml #' @export setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") }) + + +###################### Streaming Methods ########################## + +#' @rdname awaitTermination +#' @export +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) + +#' @rdname isActive +#' @export +setGeneric("isActive", function(x) { standardGeneric("isActive") }) + +#' @rdname lastProgress +#' @export +setGeneric("lastProgress", function(x) { standardGeneric("lastProgress") }) + +#' @rdname queryName +#' @export +setGeneric("queryName", function(x) { standardGeneric("queryName") }) + +#' @rdname status +#' @export +setGeneric("status", function(x) { standardGeneric("status") }) + +#' @rdname stopQuery +#' @export +setGeneric("stopQuery", function(x) { standardGeneric("stopQuery") }) diff --git a/R/pkg/R/install.R b/R/pkg/R/install.R index 69b0a523b84e4..4ca7aa664e023 100644 --- a/R/pkg/R/install.R +++ b/R/pkg/R/install.R @@ -21,9 +21,9 @@ #' Download and Install Apache Spark to a Local Directory #' #' \code{install.spark} downloads and installs Spark to a local directory if -#' it is not found. The Spark version we use is the same as the SparkR version. -#' Users can specify a desired Hadoop version, the remote mirror site, and -#' the directory where the package is installed locally. +#' it is not found. If SPARK_HOME is set in the environment, and that directory is found, that is +#' returned. The Spark version we use is the same as the SparkR version. Users can specify a desired +#' Hadoop version, the remote mirror site, and the directory where the package is installed locally. #' #' The full url of remote file is inferred from \code{mirrorUrl} and \code{hadoopVersion}. #' \code{mirrorUrl} specifies the remote path to a Spark folder. It is followed by a subfolder @@ -50,11 +50,11 @@ #' \itemize{ #' \item Mac OS X: \file{~/Library/Caches/spark} #' \item Unix: \env{$XDG_CACHE_HOME} if defined, otherwise \file{~/.cache/spark} -#' \item Windows: \file{\%LOCALAPPDATA\%\\spark\\spark\\Cache}. 
+#' \item Windows: \file{\%LOCALAPPDATA\%\\Apache\\Spark\\Cache}. #' } #' @param overwrite If \code{TRUE}, download and overwrite the existing tar file in localDir #' and force re-install Spark (in case the local directory or file is corrupted) -#' @return \code{install.spark} returns the local directory where Spark is found or installed +#' @return the (invisible) local directory where Spark is found or installed #' @rdname install.spark #' @name install.spark #' @aliases install.spark @@ -68,6 +68,16 @@ #' \href{http://spark.apache.org/downloads.html}{Apache Spark} install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, localDir = NULL, overwrite = FALSE) { + sparkHome <- Sys.getenv("SPARK_HOME") + if (isSparkRShell()) { + stopifnot(nchar(sparkHome) > 0) + message("Spark is already running in sparkR shell.") + return(invisible(sparkHome)) + } else if (!is.na(file.info(sparkHome)$isdir)) { + message("Spark package found in SPARK_HOME: ", sparkHome) + return(invisible(sparkHome)) + } + version <- paste0("spark-", packageVersion("SparkR")) hadoopVersion <- tolower(hadoopVersion) hadoopVersionName <- hadoopVersionName(hadoopVersion) @@ -79,19 +89,28 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, dir.create(localDir, recursive = TRUE) } - packageLocalDir <- file.path(localDir, packageName) - if (overwrite) { message(paste0("Overwrite = TRUE: download and overwrite the tar file", "and Spark package directory if they exist.")) } + releaseUrl <- Sys.getenv("SPARKR_RELEASE_DOWNLOAD_URL") + if (releaseUrl != "") { + packageName <- basenameSansExtFromUrl(releaseUrl) + } + + packageLocalDir <- file.path(localDir, packageName) + # can use dir.exists(packageLocalDir) under R 3.2.0 or later if (!is.na(file.info(packageLocalDir)$isdir) && !overwrite) { - fmt <- "%s for Hadoop %s found, with SPARK_HOME set to %s" - msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), - packageLocalDir) - message(msg) + if (releaseUrl != "") { + message(paste(packageName, "found, setting SPARK_HOME to", packageLocalDir)) + } else { + fmt <- "%s for Hadoop %s found, setting SPARK_HOME to %s" + msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), + packageLocalDir) + message(msg) + } Sys.setenv(SPARK_HOME = packageLocalDir) return(invisible(packageLocalDir)) } else { @@ -104,14 +123,37 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, if (tarExists && !overwrite) { message("tar file found.") } else { - robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + if (releaseUrl != "") { + message("Downloading from alternate URL:\n- ", releaseUrl) + success <- downloadUrl(releaseUrl, packageLocalPath) + if (!success) { + unlink(packageLocalPath) + stop(paste0("Fetch failed from ", releaseUrl)) + } + } else { + robustDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) + } } message(sprintf("Installing to %s", localDir)) - untar(tarfile = packageLocalPath, exdir = localDir) - if (!tarExists || overwrite) { + # There are two ways untar can fail - untar could stop() on errors like incomplete block on file + # or, tar command can return failure code + success <- tryCatch(untar(tarfile = packageLocalPath, exdir = localDir) == 0, + error = function(e) { + message(e) + message() + FALSE + }, + warning = function(w) { + # Treat warning as error, add an empty line with message() + message(w) + message() + FALSE + }) + if (!tarExists || overwrite || 
!success) { unlink(packageLocalPath) } + if (!success) stop("Extract archive failed.") message("DONE.") Sys.setenv(SPARK_HOME = packageLocalDir) message(paste("SPARK_HOME set to", packageLocalDir)) @@ -121,8 +163,7 @@ install.spark <- function(hadoopVersion = "2.7", mirrorUrl = NULL, robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { # step 1: use user-provided url if (!is.null(mirrorUrl)) { - msg <- sprintf("Use user-provided mirror site: %s.", mirrorUrl) - message(msg) + message("Use user-provided mirror site: ", mirrorUrl) success <- directDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) if (success) { @@ -142,7 +183,7 @@ robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, pa packageName, packageLocalPath) if (success) return() } else { - message("Unable to find preferred mirror site.") + message("Unable to download from preferred mirror site: ", mirrorUrl) } # step 3: use backup option @@ -151,8 +192,11 @@ robustDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, pa success <- directDownloadTar(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) if (success) { - return(packageLocalPath) + return() } else { + # remove any partially downloaded file + unlink(packageLocalPath) + message("Unable to download from default mirror site: ", mirrorUrl) msg <- sprintf(paste("Unable to download Spark %s for Hadoop %s.", "Please check network connection, Hadoop version,", "or provide other mirror sites."), @@ -182,17 +226,25 @@ getPreferredMirror <- function(version, packageName) { } directDownloadTar <- function(mirrorUrl, version, hadoopVersion, packageName, packageLocalPath) { - packageRemotePath <- paste0( - file.path(mirrorUrl, version, packageName), ".tgz") + packageRemotePath <- paste0(file.path(mirrorUrl, version, packageName), ".tgz") fmt <- "Downloading %s for Hadoop %s from:\n- %s" msg <- sprintf(fmt, version, ifelse(hadoopVersion == "without", "Free build", hadoopVersion), packageRemotePath) message(msg) + downloadUrl(packageRemotePath, packageLocalPath) +} - isFail <- tryCatch(download.file(packageRemotePath, packageLocalPath), +downloadUrl <- function(remotePath, localPath) { + isFail <- tryCatch(download.file(remotePath, localPath), error = function(e) { - message(sprintf("Fetch failed from %s", mirrorUrl)) - print(e) + message(e) + message() + TRUE + }, + warning = function(w) { + # Treat warning as error, add an empty line with message() + message(w) + message() TRUE }) !isFail @@ -218,12 +270,11 @@ sparkCachePath <- function() { if (.Platform$OS.type == "windows") { winAppPath <- Sys.getenv("LOCALAPPDATA", unset = NA) if (is.na(winAppPath)) { - msg <- paste("%LOCALAPPDATA% not found.", + stop(paste("%LOCALAPPDATA% not found.", "Please define the environment variable", - "or restart and enter an installation path in localDir.") - stop(msg) + "or restart and enter an installation path in localDir.")) } else { - path <- file.path(winAppPath, "spark", "spark", "Cache") + path <- file.path(winAppPath, "Apache", "Spark", "Cache") } } else if (.Platform$OS.type == "unix") { if (Sys.info()["sysname"] == "Darwin") { diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R deleted file mode 100644 index 7a220b8d53a2f..0000000000000 --- a/R/pkg/R/mllib.R +++ /dev/null @@ -1,1867 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# mllib.R: Provides methods for MLlib integration - -# Integration with R's standard functions. -# Most of MLlib's argorithms are provided in two flavours: -# - a specialization of the default R methods (glm). These methods try to respect -# the inputs and the outputs of R's method to the largest extent, but some small differences -# may exist. -# - a set of methods that reflect the arguments of the other languages supported by Spark. These -# methods are prefixed with the `spark.` prefix: spark.glm, spark.kmeans, etc. - -#' S4 class that represents a generalized linear model -#' -#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper -#' @export -#' @note GeneralizedLinearRegressionModel since 2.0.0 -setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a NaiveBayesModel -#' -#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper -#' @export -#' @note NaiveBayesModel since 2.0.0 -setClass("NaiveBayesModel", representation(jobj = "jobj")) - -#' S4 class that represents an LDAModel -#' -#' @param jobj a Java object reference to the backing Scala LDAWrapper -#' @export -#' @note LDAModel since 2.1.0 -setClass("LDAModel", representation(jobj = "jobj")) - -#' S4 class that represents a AFTSurvivalRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper -#' @export -#' @note AFTSurvivalRegressionModel since 2.0.0 -setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a KMeansModel -#' -#' @param jobj a Java object reference to the backing Scala KMeansModel -#' @export -#' @note KMeansModel since 2.0.0 -setClass("KMeansModel", representation(jobj = "jobj")) - -#' S4 class that represents a MultilayerPerceptronClassificationModel -#' -#' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper -#' @export -#' @note MultilayerPerceptronClassificationModel since 2.1.0 -setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj")) - -#' S4 class that represents an IsotonicRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel -#' @export -#' @note IsotonicRegressionModel since 2.1.0 -setClass("IsotonicRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a GaussianMixtureModel -#' -#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel -#' @export -#' @note GaussianMixtureModel since 2.1.0 -setClass("GaussianMixtureModel", representation(jobj = "jobj")) - -#' S4 class that represents an ALSModel -#' -#' @param jobj a Java object reference to the backing Scala ALSWrapper -#' @export -#' @note ALSModel since 2.1.0 -setClass("ALSModel", representation(jobj 
= "jobj")) - -#' S4 class that represents an KSTest -#' -#' @param jobj a Java object reference to the backing Scala KSTestWrapper -#' @export -#' @note KSTest since 2.1.0 -setClass("KSTest", representation(jobj = "jobj")) - -#' S4 class that represents an LogisticRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel -#' @export -#' @note LogisticRegressionModel since 2.1.0 -setClass("LogisticRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a RandomForestRegressionModel -#' -#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel -#' @export -#' @note RandomForestRegressionModel since 2.1.0 -setClass("RandomForestRegressionModel", representation(jobj = "jobj")) - -#' S4 class that represents a RandomForestClassificationModel -#' -#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel -#' @export -#' @note RandomForestClassificationModel since 2.1.0 -setClass("RandomForestClassificationModel", representation(jobj = "jobj")) - -#' Saves the MLlib model to the input path -#' -#' Saves the MLlib model to the input path. For more information, see the specific -#' MLlib model below. -#' @rdname write.ml -#' @name write.ml -#' @export -#' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.lda}, \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.randomForest}, \link{spark.survreg}, -#' @seealso \link{read.ml} -NULL - -#' Makes predictions from a MLlib model -#' -#' Makes predictions from a MLlib model. For more information, see the specific -#' MLlib model below. -#' @rdname predict -#' @name predict -#' @export -#' @seealso \link{spark.glm}, \link{glm}, -#' @seealso \link{spark.als}, \link{spark.gaussianMixture}, \link{spark.isoreg}, \link{spark.kmeans}, -#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes}, -#' @seealso \link{spark.randomForest}, \link{spark.survreg} -NULL - -write_internal <- function(object, path, overwrite = FALSE) { - writer <- callJMethod(object@jobj, "write") - if (overwrite) { - writer <- callJMethod(writer, "overwrite") - } - invisible(callJMethod(writer, "save", path)) -} - -predict_internal <- function(object, newData) { - dataFrame(callJMethod(object@jobj, "transform", newData@sdf)) -} - -#' Generalized Linear Models -#' -#' Fits generalized linear model against a Spark DataFrame. -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param family a description of the error distribution and link function to be used in the model. -#' This can be a character string naming a family function, a family function or -#' the result of a call to a family function. Refer R family at -#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param tol positive convergence tolerance of iterations. -#' @param maxIter integer giving the maximal number of IRLS iterations. -#' @param weightCol the weight column name. 
If this is not set or \code{NULL}, we treat all instance -#' weights as 1.0. -#' @param regParam regularization parameter for L2 regularization. -#' @param ... additional arguments passed to the method. -#' @aliases spark.glm,SparkDataFrame,formula-method -#' @return \code{spark.glm} returns a fitted generalized linear model -#' @rdname spark.glm -#' @name spark.glm -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.glm since 2.0.0 -#' @seealso \link{glm}, \link{read.ml} -setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, - regParam = 0.0) { - if (is.character(family)) { - family <- get(family, mode = "function", envir = parent.frame()) - } - if (is.function(family)) { - family <- family() - } - if (is.null(family$family)) { - print(family) - stop("'family' not recognized") - } - - formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, - tol, as.integer(maxIter), as.character(weightCol), regParam) - new("GeneralizedLinearRegressionModel", jobj = jobj) - }) - -#' Generalized Linear Models (R-compliant) -#' -#' Fits a generalized linear model, similarly to R's glm(). -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data a SparkDataFrame or R's glm data for training. -#' @param family a description of the error distribution and link function to be used in the model. -#' This can be a character string naming a family function, a family function or -#' the result of a call to a family function. Refer R family at -#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. -#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance -#' weights as 1.0. -#' @param epsilon positive convergence tolerance of iterations. -#' @param maxit integer giving the maximal number of IRLS iterations. -#' @return \code{glm} returns a fitted generalized linear model. -#' @rdname glm -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") -#' summary(model) -#' } -#' @note glm since 1.5.0 -#' @seealso \link{spark.glm} -setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), - function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { - spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) - }) - -# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). - -#' @param object a fitted generalized linear model. 
-#' @return \code{summary} returns a summary object of the fitted model, a list of components -#' including at least the coefficients, null/residual deviance, null/residual degrees -#' of freedom, AIC and number of iterations IRLS takes. -#' -#' @rdname spark.glm -#' @export -#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 -setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - features <- callJMethod(jobj, "rFeatures") - coefficients <- callJMethod(jobj, "rCoefficients") - dispersion <- callJMethod(jobj, "rDispersion") - null.deviance <- callJMethod(jobj, "rNullDeviance") - deviance <- callJMethod(jobj, "rDeviance") - df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") - df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") - aic <- callJMethod(jobj, "rAic") - iter <- callJMethod(jobj, "rNumIterations") - family <- callJMethod(jobj, "rFamily") - deviance.resid <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "rDevianceResiduals")) - } - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) - ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, - dispersion = dispersion, null.deviance = null.deviance, - deviance = deviance, df.null = df.null, df.residual = df.residual, - aic = aic, iter = iter, family = family, is.loaded = is.loaded) - class(ans) <- "summary.GeneralizedLinearRegressionModel" - ans - }) - -# Prints the summary of GeneralizedLinearRegressionModel - -#' @rdname spark.glm -#' @param x summary object of fitted generalized linear model returned by \code{summary} function -#' @export -#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 -print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { - if (x$is.loaded) { - cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n") - } else { - x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals", - c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max")) - x$deviance.resid <- zapsmall(x$deviance.resid, 5L) - cat("\nDeviance Residuals: \n") - cat("(Note: These are approximate quantiles with relative error <= 0.01)\n") - print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L) - } - - cat("\nCoefficients:\n") - print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L) - - cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion), - ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"), - format(unlist(x[c("null.deviance", "deviance")]), digits = 5L), - " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"), - 1L, paste, collapse = " "), sep = "") - cat("AIC: ", format(x$aic, digits = 4L), "\n\n", - "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "") - invisible(x) - } - -# Makes predictions from a generalized linear model produced by glm() or spark.glm(), -# similarly to R's predict(). - -#' @param newData a SparkDataFrame for testing. 
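The spark.glm and glm wrappers documented above share a single backing fit, and summary() exposes the components listed in the docs. A minimal sketch, assuming an attached SparkR session and the iris columns used in the roxygen examples:

```R
# assumes library(SparkR) is attached and sparkR.session() has been called
df <- createDataFrame(iris)

# spark.glm and the R-style glm wrapper fit the same model
m1 <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian")
m2 <- glm(Sepal_Length ~ Sepal_Width, data = df, family = "gaussian")

s <- summary(m1)
s$coefficients   # Estimate, Std. Error, t value, Pr(>|t|)
s$aic            # AIC of the fit
head(select(predict(m1, df), "Sepal_Length", "prediction"))
```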
-#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction" -#' @rdname spark.glm -#' @export -#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 -setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), -# similarly to R package e1071's predict. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" -#' @rdname spark.naiveBayes -#' @export -#' @note predict(NaiveBayesModel) since 2.0.0 -setMethod("predict", signature(object = "NaiveBayesModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} - -#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. -#' @return \code{summary} returns a list containing \code{apriori}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label. -#' @rdname spark.naiveBayes -#' @export -#' @note summary(NaiveBayesModel) since 2.0.0 -setMethod("summary", signature(object = "NaiveBayesModel"), - function(object) { - jobj <- object@jobj - features <- callJMethod(jobj, "features") - labels <- callJMethod(jobj, "labels") - apriori <- callJMethod(jobj, "apriori") - apriori <- t(as.matrix(unlist(apriori))) - colnames(apriori) <- unlist(labels) - tables <- callJMethod(jobj, "tables") - tables <- matrix(tables, nrow = length(labels)) - rownames(tables) <- unlist(labels) - colnames(tables) <- unlist(features) - list(apriori = apriori, tables = tables) - }) - -# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() - -#' @param newData A SparkDataFrame for testing -#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities -#' vectors named "topicDistribution" -#' @rdname spark.lda -#' @aliases spark.posterior,LDAModel,SparkDataFrame-method -#' @export -#' @note spark.posterior(LDAModel) since 2.1.0 -setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda} - -#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. -#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. 
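The summary(NaiveBayesModel) method above returns the label distribution and the conditional probability tables; a short sketch of reading them, assuming an attached SparkR session (the UCBAdmissions data is only illustrative):

```R
# assumes library(SparkR) is attached and sparkR.session() has been called
df <- createDataFrame(as.data.frame(UCBAdmissions))
nb <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 1.0)

s <- summary(nb)
s$apriori   # label distribution: one row, one column per label
s$tables    # conditional probabilities given the target label
```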
-#' @return \code{summary} returns a list containing -#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for -#' the prior placed on documents distributions over topics \code{theta}} -#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or -#' \code{eta} for the prior placed on topic distributions over terms} -#' \item{\code{logLikelihood}}{log likelihood of the entire corpus} -#' \item{\code{logPerplexity}}{log perplexity} -#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model} -#' \item{\code{vocabSize}}{number of terms in the corpus} -#' \item{\code{topics}}{top 10 terms and their weights of all topics} -#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file -#' used as training set} -#' @rdname spark.lda -#' @aliases summary,LDAModel-method -#' @export -#' @note summary(LDAModel) since 2.1.0 -setMethod("summary", signature(object = "LDAModel"), - function(object, maxTermsPerTopic) { - maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic)) - jobj <- object@jobj - docConcentration <- callJMethod(jobj, "docConcentration") - topicConcentration <- callJMethod(jobj, "topicConcentration") - logLikelihood <- callJMethod(jobj, "logLikelihood") - logPerplexity <- callJMethod(jobj, "logPerplexity") - isDistributed <- callJMethod(jobj, "isDistributed") - vocabSize <- callJMethod(jobj, "vocabSize") - topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic)) - vocabulary <- callJMethod(jobj, "vocabulary") - list(docConcentration = unlist(docConcentration), - topicConcentration = topicConcentration, - logLikelihood = logLikelihood, logPerplexity = logPerplexity, - isDistributed = isDistributed, vocabSize = vocabSize, - topics = topics, vocabulary = unlist(vocabulary)) - }) - -# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda} - -#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log -#' perplexity of the training data if missing argument "data". -#' @rdname spark.lda -#' @aliases spark.perplexity,LDAModel-method -#' @export -#' @note spark.perplexity(LDAModel) since 2.1.0 -setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), - function(object, data) { - ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"), - callJMethod(object@jobj, "computeLogPerplexity", data@sdf)) - }) - -# Saves the Latent Dirichlet Allocation model to the input path. - -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.lda -#' @aliases write.ml,LDAModel,character-method -#' @export -#' @seealso \link{read.ml} -#' @note write.ml(LDAModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "LDAModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' Isotonic Regression Model -#' -#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg(). -#' Users can print, make predictions on the produced model and save the model to the input path. -#' -#' @param data SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. 
Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or -#' antitonic/decreasing (FALSE) -#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column -#' (default: 0), no effect otherwise -#' @param weightCol The weight column name. -#' @param ... additional arguments passed to the method. -#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model -#' @rdname spark.isoreg -#' @aliases spark.isoreg,SparkDataFrame,formula-method -#' @name spark.isoreg -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0), -#' list(5.0, 3.0), list(1.0, 4.0)) -#' df <- createDataFrame(data, c("label", "feature")) -#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE) -#' # return model boundaries and prediction as lists -#' result <- summary(model, df) -#' # prediction based on fitted model -#' predict_data <- list(list(-2.0), list(-1.0), list(0.5), -#' list(0.75), list(1.0), list(2.0), list(9.0)) -#' predict_df <- createDataFrame(predict_data, c("feature")) -#' # get prediction column -#' predict_result <- collect(select(predict(model, predict_df), "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.isoreg since 2.1.0 -setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { - formula <- paste0(deparse(formula), collapse = "") - - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", - data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), - as.character(weightCol)) - new("IsotonicRegressionModel", jobj = jobj) - }) - -# Predicted values based on an isotonicRegression model - -#' @param object a fitted IsotonicRegressionModel -#' @param newData SparkDataFrame for testing -#' @return \code{predict} returns a SparkDataFrame containing predicted values -#' @rdname spark.isoreg -#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method -#' @export -#' @note predict(IsotonicRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "IsotonicRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Get the summary of an IsotonicRegressionModel model - -#' @return \code{summary} returns the model's boundaries and prediction as lists -#' @rdname spark.isoreg -#' @aliases summary,IsotonicRegressionModel-method -#' @export -#' @note summary(IsotonicRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "IsotonicRegressionModel"), - function(object) { - jobj <- object@jobj - boundaries <- callJMethod(jobj, "boundaries") - predictions <- callJMethod(jobj, "predictions") - list(boundaries = boundaries, predictions = predictions) - }) - -#' K-Means Clustering Model -#' -#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans(). -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. 
-#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' Note that the response variable of formula is empty in spark.kmeans. -#' @param k number of centers. -#' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.kmeans} returns a fitted k-means model. -#' @rdname spark.kmeans -#' @aliases spark.kmeans,SparkDataFrame,formula-method -#' @name spark.kmeans -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.kmeans since 2.0.0 -#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} -setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) { - formula <- paste(deparse(formula), collapse = "") - initMode <- match.arg(initMode) - jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, - as.integer(k), as.integer(maxIter), initMode) - new("KMeansModel", jobj = jobj) - }) - -#' Get fitted result from a k-means model -#' -#' Get fitted result from a k-means model, similarly to R's fitted(). -#' Note: A saved-loaded model does not support this method. -#' -#' @param object a fitted k-means model. -#' @param method type of fitted results, \code{"centers"} for cluster centers -#' or \code{"classes"} for assigned classes. -#' @param ... additional argument(s) passed to the method. -#' @return \code{fitted} returns a SparkDataFrame containing fitted values. -#' @rdname fitted -#' @export -#' @examples -#' \dontrun{ -#' model <- spark.kmeans(trainingData, ~ ., 2) -#' fitted.model <- fitted(model) -#' showDF(fitted.model) -#'} -#' @note fitted since 2.0.0 -setMethod("fitted", signature(object = "KMeansModel"), - function(object, method = c("centers", "classes")) { - method <- match.arg(method) - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - if (is.loaded) { - stop("Saved-loaded k-means model does not support 'fitted' method") - } else { - dataFrame(callJMethod(jobj, "fitted", method)) - } - }) - -# Get the summary of a k-means model - -#' @param object a fitted k-means model. -#' @return \code{summary} returns the model's coefficients, size and cluster. 
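To make those components concrete, a minimal sketch of reading a k-means summary, assuming an attached SparkR session and the iris columns from the example above:

```R
# assumes library(SparkR) is attached and sparkR.session() has been called
df <- createDataFrame(iris)
km <- spark.kmeans(df, ~ Sepal_Length + Sepal_Width, k = 3)

s <- summary(km)
s$coefficients   # cluster centers: one row per cluster, one column per feature
s$size           # number of points assigned to each cluster
head(s$cluster)  # assignments as a SparkDataFrame (NULL for a saved-loaded model)
```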
-#' @rdname spark.kmeans -#' @export -#' @note summary(KMeansModel) since 2.0.0 -setMethod("summary", signature(object = "KMeansModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - features <- callJMethod(jobj, "features") - coefficients <- callJMethod(jobj, "coefficients") - k <- callJMethod(jobj, "k") - size <- callJMethod(jobj, "size") - coefficients <- t(matrix(coefficients, ncol = k)) - colnames(coefficients) <- unlist(features) - rownames(coefficients) <- 1:k - cluster <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "cluster")) - } - list(coefficients = coefficients, size = size, - cluster = cluster, is.loaded = is.loaded) - }) - -# Predicted values based on a k-means model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns the predicted values based on a k-means model. -#' @rdname spark.kmeans -#' @export -#' @note predict(KMeansModel) since 2.0.0 -setMethod("predict", signature(object = "KMeansModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Logistic Regression Model -#' -#' Fits an logistic regression model against a Spark DataFrame. It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. -#' -#' @param data SparkDataFrame for training -#' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param regParam the regularization parameter. Default is 0.0. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. -#' @param maxIter maximum iteration number. -#' @param tol convergence tolerance of iterations. -#' @param fitIntercept whether to fit an intercept term. Default is TRUE. -#' @param family the name of family which is a description of the label distribution to be used in the model. -#' Supported options: Default is "auto". -#' \itemize{ -#' \item{"auto": Automatically select the family based on the number of classes: -#' If number of classes == 1 || number of classes == 2, set to "binomial". -#' Else, set to "multinomial".} -#' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} -#' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. 
Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. Default is 0.5. -#' @param weightCol The weight column name. -#' @param aggregationDepth depth for treeAggregate (>= 2). If the dimensions of features or the number of partitions -#' are large, this param could be adjusted to a larger size. Default is 2. -#' @param probabilityCol column name for predicted class conditional probabilities. Default is "probability". -#' @param ... additional arguments passed to the method. -#' @return \code{spark.logit} returns a fitted logistic regression model -#' @rdname spark.logit -#' @aliases spark.logit,SparkDataFrame,formula-method -#' @name spark.logit -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' # binary logistic regression -#' label <- c(1.0, 1.0, 1.0, 0.0, 0.0) -#' feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) -#' binary_data <- as.data.frame(cbind(label, feature)) -#' binary_df <- createDataFrame(binary_data) -#' blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) -#' blr_predict <- collect(select(predict(blr_model, binary_df), "prediction")) -#' -#' # summary of binary logistic regression -#' blr_summary <- summary(blr_model) -#' blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(blr_model, path) -#' -#' # can also read back the saved model and predict -#' # Note that summary deos not work on loaded model -#' savedModel <- read.ml(path) -#' blr_predict2 <- collect(select(predict(savedModel, binary_df), "prediction")) -#' -#' # multinomial logistic regression -#' -#' label <- c(0.0, 1.0, 2.0, 0.0, 0.0) -#' feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) -#' feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) -#' feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) -#' feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) -#' data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) -#' df <- createDataFrame(data) -#' -#' # Note that summary of multinomial logistic regression is not implemented yet -#' model <- spark.logit(df, label ~ ., family = "multinomial", thresholds = c(0, 1, 1)) -#' predict1 <- collect(select(predict(model, df), "prediction")) -#' } -#' @note spark.logit since 2.1.0 -setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, - tol = 1E-6, fitIntercept = TRUE, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL, aggregationDepth = 2, - probabilityCol = "probability") { - formula <- paste0(deparse(formula), collapse = "") - - if (is.null(weightCol)) { - weightCol <- "" - } - - jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", - data@sdf, formula, as.numeric(regParam), - as.numeric(elasticNetParam), as.integer(maxIter), - as.numeric(tol), as.logical(fitIntercept), - as.character(family), as.logical(standardization), - as.array(thresholds), as.character(weightCol), - as.integer(aggregationDepth), as.character(probabilityCol)) - new("LogisticRegressionModel", jobj = jobj) - }) - -# Predicted values based on an LogisticRegressionModel model - -#' @param newData a SparkDataFrame 
for testing. -#' @return \code{predict} returns the predicted values based on an LogisticRegressionModel. -#' @rdname spark.logit -#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method -#' @export -#' @note predict(LogisticRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "LogisticRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Get the summary of an LogisticRegressionModel - -#' @param object an LogisticRegressionModel fitted by \code{spark.logit} -#' @return \code{summary} returns the Binary Logistic regression results of a given model as lists. Note that -#' Multinomial logistic regression summary is not available now. -#' @rdname spark.logit -#' @aliases summary,LogisticRegressionModel-method -#' @export -#' @note summary(LogisticRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "LogisticRegressionModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - - if (is.loaded) { - stop("Loaded model doesn't have training summary.") - } - - roc <- dataFrame(callJMethod(jobj, "roc")) - - areaUnderROC <- callJMethod(jobj, "areaUnderROC") - - pr <- dataFrame(callJMethod(jobj, "pr")) - - fMeasureByThreshold <- dataFrame(callJMethod(jobj, "fMeasureByThreshold")) - - precisionByThreshold <- dataFrame(callJMethod(jobj, "precisionByThreshold")) - - recallByThreshold <- dataFrame(callJMethod(jobj, "recallByThreshold")) - - totalIterations <- callJMethod(jobj, "totalIterations") - - objectiveHistory <- callJMethod(jobj, "objectiveHistory") - - list(roc = roc, areaUnderROC = areaUnderROC, pr = pr, - fMeasureByThreshold = fMeasureByThreshold, - precisionByThreshold = precisionByThreshold, - recallByThreshold = recallByThreshold, - totalIterations = totalIterations, objectiveHistory = objectiveHistory) - }) - -#' Multilayer Perceptron Classification Model -#' -#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. -#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' Only categorical data is supported. -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{ -#' Multilayer Perceptron} -#' -#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. -#' @param blockSize blockSize parameter. -#' @param layers integer vector containing the number of nodes for each layer -#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". -#' @param maxIter maximum iteration number. -#' @param tol convergence tolerance of iterations. -#' @param stepSize stepSize parameter. -#' @param seed seed parameter for weights initialization. -#' @param initialWeights initialWeights parameter for weights initialization, it should be a -#' numeric vector. -#' @param ... additional arguments passed to the method. -#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
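The binary-classification summary returned by summary() on a spark.logit fit, described a little earlier, exposes threshold-indexed metrics as SparkDataFrames; a hedged sketch of collecting them, reusing the binary_df frame from the spark.logit example:

```R
# binary_df is the label/feature SparkDataFrame from the spark.logit example
model <- spark.logit(binary_df, label ~ feature)
s <- summary(model)

s$areaUnderROC                                   # scalar AUC
head(collect(s$roc))                             # FPR/TPR pairs
head(collect(select(s$fMeasureByThreshold, "threshold", "F-Measure")))
```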
-#' @rdname spark.mlp -#' @aliases spark.mlp,SparkDataFrame-method -#' @name spark.mlp -#' @seealso \link{read.ml} -#' @export -#' @examples -#' \dontrun{ -#' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") -#' -#' # fit a Multilayer Perceptron Classification Model -#' model <- spark.mlp(df, blockSize = 128, layers = c(4, 3), solver = "l-bfgs", -#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, -#' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.mlp since 2.1.0 -setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, - tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) { - if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") - } - layers <- as.integer(na.omit(layers)) - if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") - } - if (!is.null(seed)) { - seed <- as.character(as.integer(seed)) - } - if (!is.null(initialWeights)) { - initialWeights <- as.array(as.numeric(na.omit(initialWeights))) - } - jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", - "fit", data@sdf, as.integer(blockSize), as.array(layers), - as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), seed, initialWeights) - new("MultilayerPerceptronClassificationModel", jobj = jobj) - }) - -# Makes predictions from a model produced by spark.mlp(). - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction". -#' @rdname spark.mlp -#' @aliases predict,MultilayerPerceptronClassificationModel-method -#' @export -#' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 -setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} - -#' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{labelCount}, \code{layers}, and -#' \code{weights}. For \code{weights}, it is a numeric vector with length equal to -#' the expected given the architecture (i.e., for 8-10-2 network, 100 connection weights). -#' @rdname spark.mlp -#' @export -#' @aliases summary,MultilayerPerceptronClassificationModel-method -#' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0 -setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"), - function(object) { - jobj <- object@jobj - labelCount <- callJMethod(jobj, "labelCount") - layers <- unlist(callJMethod(jobj, "layers")) - weights <- callJMethod(jobj, "weights") - list(labelCount = labelCount, layers = layers, weights = weights) - }) - -#' Naive Bayes Models -#' -#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame. 
-#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make -#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' Only categorical data is supported. -#' -#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param smoothing smoothing parameter. -#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. -#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. -#' @rdname spark.naiveBayes -#' @aliases spark.naiveBayes,SparkDataFrame,formula-method -#' @name spark.naiveBayes -#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} -#' @export -#' @examples -#' \dontrun{ -#' data <- as.data.frame(UCBAdmissions) -#' df <- createDataFrame(data) -#' -#' # fit a Bernoulli naive Bayes model -#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.naiveBayes since 2.0.0 -setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, smoothing = 1.0) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", - formula, data@sdf, smoothing) - new("NaiveBayesModel", jobj = jobj) - }) - -# Saves the Bernoulli naive Bayes model to the input path. - -#' @param path the directory where the model is saved -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.naiveBayes -#' @export -#' @seealso \link{write.ml} -#' @note write.ml(NaiveBayesModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the AFT survival regression model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' @rdname spark.survreg -#' @export -#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0 -#' @seealso \link{write.ml} -setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the generalized linear model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. 
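All write.ml methods in this file take the same path/overwrite arguments; a sketch of a save, overwrite, and reload round trip (the path is a placeholder, and `model` can be any fitted model, such as the naive Bayes fit above):

```R
# `model` is assumed to be a fitted model, e.g. from spark.naiveBayes() above
path <- tempfile("sparkr-model-")        # placeholder local path
write.ml(model, path)                    # default overwrite = FALSE: errors if the path exists
write.ml(model, path, overwrite = TRUE)  # replaces the earlier save
reloaded <- read.ml(path)
summary(reloaded)
```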
-#' -#' @rdname spark.glm -#' @export -#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted MLlib model to the input path - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.kmeans -#' @export -#' @note write.ml(KMeansModel, character) since 2.0.0 -setMethod("write.ml", signature(object = "KMeansModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Saves the Multilayer Perceptron Classification Model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.mlp -#' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method -#' @export -#' @seealso \link{write.ml} -#' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", - path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted IsotonicRegressionModel to the input path - -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.isoreg -#' @aliases write.ml,IsotonicRegressionModel,character-method -#' @export -#' @note write.ml(IsotonicRegression, character) since 2.1.0 -setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted LogisticRegressionModel to the input path - -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @rdname spark.logit -#' @aliases write.ml,LogisticRegressionModel,character-method -#' @export -#' @note write.ml(LogisticRegression, character) since 2.1.0 -setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Save fitted MLlib model to the input path - -#' @param path the directory where the model is saved. -#' @param overwrite overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @aliases write.ml,GaussianMixtureModel,character-method -#' @rdname spark.gaussianMixture -#' @export -#' @note write.ml(GaussianMixtureModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' Load a fitted MLlib model from the input path. -#' -#' @param path path of the model to read. -#' @return A fitted MLlib model. 
-#' @rdname read.ml -#' @name read.ml -#' @export -#' @seealso \link{write.ml} -#' @examples -#' \dontrun{ -#' path <- "path/to/model" -#' model <- read.ml(path) -#' } -#' @note read.ml since 2.0.0 -read.ml <- function(path) { - path <- suppressWarnings(normalizePath(path)) - jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) - if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { - new("NaiveBayesModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) { - new("AFTSurvivalRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) { - new("GeneralizedLinearRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) { - new("KMeansModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) { - new("LDAModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper")) { - new("MultilayerPerceptronClassificationModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { - new("IsotonicRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { - new("GaussianMixtureModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { - new("ALSModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { - new("LogisticRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { - new("RandomForestRegressionModel", jobj = jobj) - } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { - new("RandomForestClassificationModel", jobj = jobj) - } else { - stop("Unsupported model: ", jobj) - } -} - -#' Accelerated Failure Time (AFT) Survival Regression Model -#' -#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on -#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model, -#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to -#' save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. -#' Note that operator '.' is not supported currently. -#' @return \code{spark.survreg} returns a fitted AFT survival regression model. 
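read.ml dispatches on the saved wrapper class, so the loaded object supports the same predict/summary/write.ml methods as the original fit; a small sketch with a placeholder path:

```R
# "path/to/model" is a placeholder, as in the roxygen examples
model <- read.ml("path/to/model")
class(model)   # e.g. "NaiveBayesModel", "KMeansModel", ...
```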
-#' @rdname spark.survreg -#' @seealso survival: \url{https://cran.r-project.org/package=survival} -#' @export -#' @examples -#' \dontrun{ -#' df <- createDataFrame(ovarian) -#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx) -#' -#' # get a summary of the model -#' summary(model) -#' -#' # make predictions -#' predicted <- predict(model, df) -#' showDF(predicted) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.survreg since 2.0.0 -setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf) - new("AFTSurvivalRegressionModel", jobj = jobj) - }) - -#' Latent Dirichlet Allocation -#' -#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call -#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute -#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new -#' data and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' @param data A SparkDataFrame for training -#' @param features Features column name, default "features". Either libSVM-format column or -#' character-format column is valid. -#' @param k Number of topics, default 10 -#' @param maxIter Maximum iterations, default 20 -#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online" -#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in -#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05 -#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for -#' the prior placed on topic distributions over terms, default -1 to set automatically on the -#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size -#' numeric is accepted. -#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the -#' prior placed on documents distributions over topics (\code{theta}), default -1 to set -#' automatically on the Spark side. Use \code{summary} to retrieve the effective -#' docConcentration. Only 1-size or \code{k}-size numeric is accepted. -#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the -#' parameter if libSVM-format column is used as the features column. -#' @param maxVocabSize maximum vocabulary size, default 1 << 18 -#' @param ... additional argument(s) passed to the method. 
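Several of the spark.lda tuning parameters above (optimizer, subsamplingRate, customizedStopWords, maxVocabSize) are described but not exercised in the example that follows; a hedged sketch of passing them explicitly, assuming `text_df` has a character-format column named "features":

```R
# text_df is assumed to be a SparkDataFrame with a character column "features"
model <- spark.lda(text_df, features = "features", k = 5, maxIter = 30,
                   optimizer = "online", subsamplingRate = 0.1,
                   customizedStopWords = c("the", "a"),
                   maxVocabSize = bitwShiftL(1, 16))
summary(model, maxTermsPerTopic = 5)
```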
-#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model -#' @rdname spark.lda -#' @aliases spark.lda,SparkDataFrame-method -#' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} -#' @export -#' @examples -#' \dontrun{ -#' # nolint start -#' # An example "path/to/file" can be -#' # paste0(Sys.getenv("SPARK_HOME"), "/data/mllib/sample_lda_libsvm_data.txt") -#' # nolint end -#' text <- read.df("path/to/file", source = "libsvm") -#' model <- spark.lda(data = text, optimizer = "em") -#' -#' # get a summary of the model -#' summary(model) -#' -#' # compute posterior probabilities -#' posterior <- spark.posterior(model, text) -#' showDF(posterior) -#' -#' # compute perplexity -#' perplexity <- spark.perplexity(model, text) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.lda since 2.1.0 -setMethod("spark.lda", signature(data = "SparkDataFrame"), - function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"), - subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1, - customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) { - optimizer <- match.arg(optimizer) - jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features, - as.integer(k), as.integer(maxIter), optimizer, - as.numeric(subsamplingRate), topicConcentration, - as.array(docConcentration), as.array(customizedStopWords), - maxVocabSize) - new("LDAModel", jobj = jobj) - }) - -# Returns a summary of the AFT survival regression model produced by spark.survreg, -# similarly to R's summary(). - -#' @param object a fitted AFT survival regression model. -#' @return \code{summary} returns a list containing the model's coefficients, -#' intercept and log(scale) -#' @rdname spark.survreg -#' @export -#' @note summary(AFTSurvivalRegressionModel) since 2.0.0 -setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), - function(object) { - jobj <- object@jobj - features <- callJMethod(jobj, "rFeatures") - coefficients <- callJMethod(jobj, "rCoefficients") - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Value") - rownames(coefficients) <- unlist(features) - list(coefficients = coefficients) - }) - -# Makes predictions from an AFT survival regression model or a model produced by -# spark.survreg, similarly to R package survival's predict. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted values -#' on the original scale of the data (mean predicted value at scale = 1.0). -#' @rdname spark.survreg -#' @export -#' @note predict(AFTSurvivalRegressionModel) since 2.0.0 -setMethod("predict", signature(object = "AFTSurvivalRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Multivariate Gaussian Mixture Model (GMM) -#' -#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's -#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, -#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} -#' to save/load fitted models. -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. 
-#' Note that the response variable of formula is empty in spark.gaussianMixture. -#' @param k number of independent Gaussians in the mixture model. -#' @param maxIter maximum iteration number. -#' @param tol the convergence tolerance. -#' @param ... additional arguments passed to the method. -#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method -#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model. -#' @rdname spark.gaussianMixture -#' @name spark.gaussianMixture -#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} -#' @export -#' @examples -#' \dontrun{ -#' sparkR.session() -#' library(mvtnorm) -#' set.seed(100) -#' a <- rmvnorm(4, c(0, 0)) -#' b <- rmvnorm(6, c(3, 4)) -#' data <- rbind(a, b) -#' df <- createDataFrame(as.data.frame(data)) -#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) -#' summary(model) -#' -#' # fitted values on training data -#' fitted <- predict(model, df) -#' head(select(fitted, "V1", "prediction")) -#' -#' # save fitted model to input path -#' path <- "path/to/model" -#' write.ml(model, path) -#' -#' # can also read back the saved model and print -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' } -#' @note spark.gaussianMixture since 2.1.0 -#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} -setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, k = 2, maxIter = 100, tol = 0.01) { - formula <- paste(deparse(formula), collapse = "") - jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf, - formula, as.integer(k), as.integer(maxIter), as.numeric(tol)) - new("GaussianMixtureModel", jobj = jobj) - }) - -# Get the summary of a multivariate gaussian mixture model - -#' @param object a fitted gaussian mixture model. -#' @return \code{summary} returns the model's lambda, mu, sigma and posterior. -#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method -#' @rdname spark.gaussianMixture -#' @export -#' @note summary(GaussianMixtureModel) since 2.1.0 -setMethod("summary", signature(object = "GaussianMixtureModel"), - function(object) { - jobj <- object@jobj - is.loaded <- callJMethod(jobj, "isLoaded") - lambda <- unlist(callJMethod(jobj, "lambda")) - muList <- callJMethod(jobj, "mu") - sigmaList <- callJMethod(jobj, "sigma") - k <- callJMethod(jobj, "k") - dim <- callJMethod(jobj, "dim") - mu <- c() - for (i in 1 : k) { - start <- (i - 1) * dim + 1 - end <- i * dim - mu[[i]] <- unlist(muList[start : end]) - } - sigma <- c() - for (i in 1 : k) { - start <- (i - 1) * dim * dim + 1 - end <- i * dim * dim - sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim)) - } - posterior <- if (is.loaded) { - NULL - } else { - dataFrame(callJMethod(jobj, "posterior")) - } - list(lambda = lambda, mu = mu, sigma = sigma, - posterior = posterior, is.loaded = is.loaded) - }) - -# Predicted values based on a gaussian mixture model - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named -#' "prediction". 
-#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method -#' @rdname spark.gaussianMixture -#' @export -#' @note predict(GaussianMixtureModel) since 2.1.0 -setMethod("predict", signature(object = "GaussianMixtureModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' Alternating Least Squares (ALS) for Collaborative Filtering -#' -#' \code{spark.als} learns latent factors in collaborative filtering via alternating least -#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} -#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. -#' -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: -#' Collaborative Filtering}. -#' -#' @param data a SparkDataFrame for training. -#' @param ratingCol column name for ratings. -#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. -#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. -#' @param rank rank of the matrix factorization (> 0). -#' @param reg regularization parameter (>= 0). -#' @param maxIter maximum number of iterations (>= 0). -#' @param nonnegative logical value indicating whether to apply nonnegativity constraints. -#' @param implicitPrefs logical value indicating whether to use implicit preference. -#' @param alpha alpha parameter in the implicit preference formulation (>= 0). -#' @param seed integer seed for random number generation. -#' @param numUserBlocks number of user blocks used to parallelize computation (> 0). -#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). -#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). -#' @param ... additional argument(s) passed to the method. 
-#' @return \code{spark.als} returns a fitted ALS model -#' @rdname spark.als -#' @aliases spark.als,SparkDataFrame-method -#' @name spark.als -#' @export -#' @examples -#' \dontrun{ -#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), -#' list(2, 1, 1.0), list(2, 2, 5.0)) -#' df <- createDataFrame(ratings, c("user", "item", "rating")) -#' model <- spark.als(df, "rating", "user", "item") -#' -#' # extract latent factors -#' stats <- summary(model) -#' userFactors <- stats$userFactors -#' itemFactors <- stats$itemFactors -#' -#' # make predictions -#' predicted <- predict(model, df) -#' showDF(predicted) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # set other arguments -#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, -#' reg = 0.1, nonnegative = TRUE) -#' statsS <- summary(modelS) -#' } -#' @note spark.als since 2.1.0 -setMethod("spark.als", signature(data = "SparkDataFrame"), - function(data, ratingCol = "rating", userCol = "user", itemCol = "item", - rank = 10, reg = 0.1, maxIter = 10, nonnegative = FALSE, - implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, - checkpointInterval = 10, seed = 0) { - - if (!is.numeric(rank) || rank <= 0) { - stop("rank should be a positive number.") - } - if (!is.numeric(reg) || reg < 0) { - stop("reg should be a nonnegative number.") - } - if (!is.numeric(maxIter) || maxIter <= 0) { - stop("maxIter should be a positive number.") - } - - jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", - "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), - reg, as.integer(maxIter), implicitPrefs, alpha, nonnegative, - as.integer(numUserBlocks), as.integer(numItemBlocks), - as.integer(checkpointInterval), as.integer(seed)) - new("ALSModel", jobj = jobj) - }) - -# Returns a summary of the ALS model produced by spark.als. - -#' @param object a fitted ALS model. -#' @return \code{summary} returns a list containing the names of the user column, -#' the item column and the rating column, the estimated user and item factors, -#' rank, regularization parameter and maximum number of iterations used in training. -#' @rdname spark.als -#' @aliases summary,ALSModel-method -#' @export -#' @note summary(ALSModel) since 2.1.0 -setMethod("summary", signature(object = "ALSModel"), - function(object) { - jobj <- object@jobj - user <- callJMethod(jobj, "userCol") - item <- callJMethod(jobj, "itemCol") - rating <- callJMethod(jobj, "ratingCol") - userFactors <- dataFrame(callJMethod(jobj, "userFactors")) - itemFactors <- dataFrame(callJMethod(jobj, "itemFactors")) - rank <- callJMethod(jobj, "rank") - list(user = user, item = item, rating = rating, userFactors = userFactors, - itemFactors = itemFactors, rank = rank) - }) - - -# Makes predictions from an ALS model or a model produced by spark.als. - -#' @param newData a SparkDataFrame for testing. -#' @return \code{predict} returns a SparkDataFrame containing predicted values. -#' @rdname spark.als -#' @aliases predict,ALSModel-method -#' @export -#' @note predict(ALSModel) since 2.1.0 -setMethod("predict", signature(object = "ALSModel"), - function(object, newData) { - predict_internal(object, newData) - }) - - -# Saves the ALS model to the input path. - -#' @param path the directory where the model is saved. -#' @param overwrite logical value indicating whether to overwrite if the output path -#' already exists. 
Default is FALSE which means throw exception -#' if the output path exists. -#' -#' @rdname spark.als -#' @aliases write.ml,ALSModel,character-method -#' @export -#' @seealso \link{read.ml} -#' @note write.ml(ALSModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "ALSModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' (One-Sample) Kolmogorov-Smirnov Test -#' -#' @description -#' \code{spark.kstest} Conduct the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a -#' continuous distribution. -#' -#' By comparing the largest difference between the empirical cumulative -#' distribution of the sample data and the theoretical distribution we can provide a test for the -#' the null hypothesis that the sample data comes from that theoretical distribution. -#' -#' Users can call \code{summary} to obtain a summary of the test, and \code{print.summary.KSTest} -#' to print out a summary result. -#' -#' @param data a SparkDataFrame of user data. -#' @param testCol column name where the test data is from. It should be a column of double type. -#' @param nullHypothesis name of the theoretical distribution tested against. Currently only -#' \code{"norm"} for normal distribution is supported. -#' @param distParams parameters(s) of the distribution. For \code{nullHypothesis = "norm"}, -#' we can provide as a vector the mean and standard deviation of -#' the distribution. If none is provided, then standard normal will be used. -#' If only one is provided, then the standard deviation will be set to be one. -#' @param ... additional argument(s) passed to the method. -#' @return \code{spark.kstest} returns a test result object. -#' @rdname spark.kstest -#' @aliases spark.kstest,SparkDataFrame-method -#' @name spark.kstest -#' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ -#' MLlib: Hypothesis Testing} -#' @export -#' @examples -#' \dontrun{ -#' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) -#' df <- createDataFrame(data) -#' test <- spark.ktest(df, "test", "norm", c(0, 1)) -#' -#' # get a summary of the test result -#' testSummary <- summary(test) -#' testSummary -#' -#' # print out the summary in an organized way -#' print.summary.KSTest(test) -#' } -#' @note spark.kstest since 2.1.0 -setMethod("spark.kstest", signature(data = "SparkDataFrame"), - function(data, testCol = "test", nullHypothesis = c("norm"), distParams = c(0, 1)) { - tryCatch(match.arg(nullHypothesis), - error = function(e) { - msg <- paste("Distribution", nullHypothesis, "is not supported.") - stop(msg) - }) - if (nullHypothesis == "norm") { - distParams <- as.numeric(distParams) - mu <- ifelse(length(distParams) < 1, 0, distParams[1]) - sigma <- ifelse(length(distParams) < 2, 1, distParams[2]) - jobj <- callJStatic("org.apache.spark.ml.r.KSTestWrapper", - "test", data@sdf, testCol, nullHypothesis, - as.array(c(mu, sigma))) - new("KSTest", jobj = jobj) - } -}) - -# Get the summary of Kolmogorov-Smirnov (KS) Test. -#' @param object test result object of KSTest by \code{spark.kstest}. -#' @return \code{summary} returns a list containing the p-value, test statistic computed for the -#' test, the null hypothesis with its parameters tested against -#' and degrees of freedom of the test. 
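# A minimal sketch of the distParams handling implemented above: with no
# parameters the standard normal is used, and a single value sets the mean
# while the standard deviation falls back to 1 (df as in the example above).
t1 <- spark.kstest(df, "test", "norm")          # mu = 0, sigma = 1
t2 <- spark.kstest(df, "test", "norm", c(0.2))  # mu = 0.2, sigma defaults to 1
summary(t2)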
-#' @rdname spark.kstest -#' @aliases summary,KSTest-method -#' @export -#' @note summary(KSTest) since 2.1.0 -setMethod("summary", signature(object = "KSTest"), - function(object) { - jobj <- object@jobj - pValue <- callJMethod(jobj, "pValue") - statistic <- callJMethod(jobj, "statistic") - nullHypothesis <- callJMethod(jobj, "nullHypothesis") - distName <- callJMethod(jobj, "distName") - distParams <- unlist(callJMethod(jobj, "distParams")) - degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom") - - ans <- list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, - nullHypothesis.name = distName, nullHypothesis.parameters = distParams, - degreesOfFreedom = degreesOfFreedom, jobj = jobj) - class(ans) <- "summary.KSTest" - ans - }) - -# Prints the summary of KSTest - -#' @rdname spark.kstest -#' @param x summary object of KSTest returned by \code{summary}. -#' @export -#' @note print.summary.KSTest since 2.1.0 -print.summary.KSTest <- function(x, ...) { - jobj <- x$jobj - summaryStr <- callJMethod(jobj, "summary") - cat(summaryStr, "\n") - invisible(x) -} - -#' Random Forest Model for Regression and Classification -#' -#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on -#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest -#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to -#' save/load fitted models. -#' For more details, see -#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{Random Forest} -#' -#' @param data a SparkDataFrame for training. -#' @param formula a symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', ':', '+', and '-'. -#' @param type type of model, one of "regression" or "classification", to fit -#' @param maxDepth Maximum depth of the tree (>= 0). (default = 5) -#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing -#' how to split on features at each node. More bins give higher granularity. Must be -#' >= 2 and >= number of categories in any categorical feature. (default = 32) -#' @param numTrees Number of trees to train (>= 1). -#' @param impurity Criterion used for information gain calculation. -#' For regression, must be "variance". For classification, must be one of -#' "entropy" and "gini". (default = gini) -#' @param minInstancesPerNode Minimum number of instances each child must have after split. -#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. -#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). -#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. -#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. -#' @param seed integer seed for random number generation. -#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in -#' range (0, 1]. (default = 1.0) -#' @param probabilityCol column name for predicted class conditional probabilities, only for -#' classification. (default = "probability") -#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. -#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with -#' nodes. -#' @param ... additional arguments passed to the method. 
-#' @aliases spark.randomForest,SparkDataFrame,formula-method -#' @return \code{spark.randomForest} returns a fitted Random Forest model. -#' @rdname spark.randomForest -#' @name spark.randomForest -#' @export -#' @examples -#' \dontrun{ -#' # fit a Random Forest Regression Model -#' df <- createDataFrame(longley) -#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) -#' -#' # get the summary of the model -#' summary(model) -#' -#' # make predictions -#' predictions <- predict(model, df) -#' -#' # save and load the model -#' path <- "path/to/model" -#' write.ml(model, path) -#' savedModel <- read.ml(path) -#' summary(savedModel) -#' -#' # fit a Random Forest Classification Model -#' df <- createDataFrame(iris) -#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") -#' } -#' @note spark.randomForest since 2.1.0 -setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula, type = c("regression", "classification"), - maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, - minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, - featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, - probabilityCol = "probability", maxMemoryInMB = 256, cacheNodeIds = FALSE) { - type <- match.arg(type) - formula <- paste(deparse(formula), collapse = "") - if (!is.null(seed)) { - seed <- as.character(as.integer(seed)) - } - switch(type, - regression = { - if (is.null(impurity)) impurity <- "variance" - impurity <- match.arg(impurity, "variance") - jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(numTrees), - impurity, as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("RandomForestRegressionModel", jobj = jobj) - }, - classification = { - if (is.null(impurity)) impurity <- "gini" - impurity <- match.arg(impurity, c("gini", "entropy")) - jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", - "fit", data@sdf, formula, as.integer(maxDepth), - as.integer(maxBins), as.integer(numTrees), - impurity, as.integer(minInstancesPerNode), - as.numeric(minInfoGain), as.integer(checkpointInterval), - as.character(featureSubsetStrategy), seed, - as.numeric(subsamplingRate), as.character(probabilityCol), - as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) - new("RandomForestClassificationModel", jobj = jobj) - } - ) - }) - -# Makes predictions from a Random Forest Regression model or Classification model - -#' @param newData a SparkDataFrame for testing. 
-#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named -#' "prediction" -#' @rdname spark.randomForest -#' @aliases predict,RandomForestRegressionModel-method -#' @export -#' @note predict(randomForestRegressionModel) since 2.1.0 -setMethod("predict", signature(object = "RandomForestRegressionModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -#' @rdname spark.randomForest -#' @aliases predict,RandomForestClassificationModel-method -#' @export -#' @note predict(randomForestClassificationModel) since 2.1.0 -setMethod("predict", signature(object = "RandomForestClassificationModel"), - function(object, newData) { - predict_internal(object, newData) - }) - -# Save the Random Forest Regression or Classification model to the input path. - -#' @param object A fitted Random Forest regression model or classification model -#' @param path The directory where the model is saved -#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE -#' which means throw exception if the output path exists. -#' -#' @aliases write.ml,RandomForestRegressionModel,character-method -#' @rdname spark.randomForest -#' @export -#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -#' @aliases write.ml,RandomForestClassificationModel,character-method -#' @rdname spark.randomForest -#' @export -#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 -setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), - function(object, path, overwrite = FALSE) { - write_internal(object, path, overwrite) - }) - -# Get the summary of an RandomForestRegressionModel model -summary.randomForest <- function(model) { - jobj <- model@jobj - formula <- callJMethod(jobj, "formula") - numFeatures <- callJMethod(jobj, "numFeatures") - features <- callJMethod(jobj, "features") - featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") - numTrees <- callJMethod(jobj, "numTrees") - treeWeights <- callJMethod(jobj, "treeWeights") - list(formula = formula, - numFeatures = numFeatures, - features = features, - featureImportances = featureImportances, - numTrees = numTrees, - treeWeights = treeWeights, - jobj = jobj) -} - -#' @return \code{summary} returns the model's features as lists, depth and number of nodes -#' or number of classes. 
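# A short sketch of reading the list assembled by summary.randomForest above
# (model assumed to come from spark.randomForest as in its examples):
s <- summary(model)
s$numTrees              # number of trees trained
s$featureImportances    # string form of the feature importance vector
unlist(s$treeWeights)   # per-tree weights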
-#' @rdname spark.randomForest -#' @aliases summary,RandomForestRegressionModel-method -#' @export -#' @note summary(RandomForestRegressionModel) since 2.1.0 -setMethod("summary", signature(object = "RandomForestRegressionModel"), - function(object) { - ans <- summary.randomForest(object) - class(ans) <- "summary.RandomForestRegressionModel" - ans - }) - -# Get the summary of an RandomForestClassificationModel model - -#' @rdname spark.randomForest -#' @aliases summary,RandomForestClassificationModel-method -#' @export -#' @note summary(RandomForestClassificationModel) since 2.1.0 -setMethod("summary", signature(object = "RandomForestClassificationModel"), - function(object) { - ans <- summary.randomForest(object) - class(ans) <- "summary.RandomForestClassificationModel" - ans - }) - -# Prints the summary of Random Forest Regression Model -print.summary.randomForest <- function(x) { - jobj <- x$jobj - cat("Formula: ", x$formula) - cat("\nNumber of features: ", x$numFeatures) - cat("\nFeatures: ", unlist(x$features)) - cat("\nFeature importances: ", x$featureImportances) - cat("\nNumber of trees: ", x$numTrees) - cat("\nTree weights: ", unlist(x$treeWeights)) - - summaryStr <- callJMethod(jobj, "summary") - cat("\n", summaryStr, "\n") - invisible(x) -} - -#' @param x summary object of Random Forest regression model or classification model -#' returned by \code{summary}. -#' @rdname spark.randomForest -#' @export -#' @note print.summary.RandomForestRegressionModel since 2.1.0 -print.summary.RandomForestRegressionModel <- function(x, ...) { - print.summary.randomForest(x) -} - -# Prints the summary of Random Forest Classification Model - -#' @rdname spark.randomForest -#' @export -#' @note print.summary.RandomForestClassificationModel since 2.1.0 -print.summary.RandomForestClassificationModel <- function(x, ...) { - print.summary.randomForest(x) -} diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R new file mode 100644 index 0000000000000..4db9cc30fb0c1 --- /dev/null +++ b/R/pkg/R/mllib_classification.R @@ -0,0 +1,553 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# mllib_classification.R: Provides methods for MLlib classification algorithms
+# (except for tree-based algorithms) integration
+
+#' S4 class that represents a LinearSVCModel
+#'
+#' @param jobj a Java object reference to the backing Scala LinearSVCModel
+#' @export
+#' @note LinearSVCModel since 2.2.0
+setClass("LinearSVCModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a LogisticRegressionModel
+#'
+#' @param jobj a Java object reference to the backing Scala LogisticRegressionModel
+#' @export
+#' @note LogisticRegressionModel since 2.1.0
+setClass("LogisticRegressionModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a MultilayerPerceptronClassificationModel
+#'
+#' @param jobj a Java object reference to the backing Scala MultilayerPerceptronClassifierWrapper
+#' @export
+#' @note MultilayerPerceptronClassificationModel since 2.1.0
+setClass("MultilayerPerceptronClassificationModel", representation(jobj = "jobj"))
+
+#' S4 class that represents a NaiveBayesModel
+#'
+#' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper
+#' @export
+#' @note NaiveBayesModel since 2.0.0
+setClass("NaiveBayesModel", representation(jobj = "jobj"))
+
+#' Linear SVM Model
+#'
+#' Fits a linear SVM model against a SparkDataFrame. It is a binary classifier, similar to svm in the glmnet package.
+#' Users can print, make predictions on the produced model and save the model to the input path.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param regParam The regularization parameter.
+#' @param maxIter Maximum iteration number.
+#' @param tol Convergence tolerance of iterations.
+#' @param standardization Whether to standardize the training features before fitting the model. The coefficients
+#' of models will always be returned on the original scale, so it will be transparent for
+#' users. Note that with or without standardization, the models should always converge
+#' to the same solution when no regularization is applied.
+#' @param threshold The threshold in binary classification, in range [0, 1].
+#' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.svmLinear} returns a fitted linear SVM model.
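# A minimal sketch of the optional arguments documented above, reusing the
# Titanic training frame from the example below; "Freq" serves as the weight column.
model <- spark.svmLinear(training, Survived ~ Class + Sex + Age, regParam = 0.1,
                         threshold = 0.3, weightCol = "Freq")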
+#' @rdname spark.svmLinear
+#' @aliases spark.svmLinear,SparkDataFrame,formula-method
+#' @name spark.svmLinear
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' t <- as.data.frame(Titanic)
+#' training <- createDataFrame(t)
+#' model <- spark.svmLinear(training, Survived ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, training)
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and predict
+#' # Note that summary does not work on a loaded model
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.svmLinear since 2.2.0
+setMethod("spark.svmLinear", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, regParam = 0.0, maxIter = 100, tol = 1E-6, standardization = TRUE,
+                   threshold = 0.0, weightCol = NULL, aggregationDepth = 2) {
+            formula <- paste(deparse(formula), collapse = "")
+
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.LinearSVCWrapper", "fit",
+                                data@sdf, formula, as.numeric(regParam), as.integer(maxIter),
+                                as.numeric(tol), as.logical(standardization), as.numeric(threshold),
+                                weightCol, as.integer(aggregationDepth))
+            new("LinearSVCModel", jobj = jobj)
+          })
+
+# Predicted values based on a LinearSVCModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a LinearSVCModel.
+#' @rdname spark.svmLinear
+#' @aliases predict,LinearSVCModel,SparkDataFrame-method
+#' @export
+#' @note predict(LinearSVCModel) since 2.2.0
+setMethod("predict", signature(object = "LinearSVCModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Get the summary of a LinearSVCModel
+
+#' @param object a LinearSVCModel fitted by \code{spark.svmLinear}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' The list includes \code{coefficients} (coefficients of the fitted model),
+#' \code{intercept} (intercept of the fitted model), \code{numClasses} (number of classes),
+#' \code{numFeatures} (number of features).
+#' @rdname spark.svmLinear
+#' @aliases summary,LinearSVCModel-method
+#' @export
+#' @note summary(LinearSVCModel) since 2.2.0
+setMethod("summary", signature(object = "LinearSVCModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "features")
+            labels <- callJMethod(jobj, "labels")
+            coefficients <- callJMethod(jobj, "coefficients")
+            nCol <- length(coefficients) / length(features)
+            coefficients <- matrix(unlist(coefficients), ncol = nCol)
+            intercept <- callJMethod(jobj, "intercept")
+            numClasses <- callJMethod(jobj, "numClasses")
+            numFeatures <- callJMethod(jobj, "numFeatures")
+            if (nCol == 1) {
+              colnames(coefficients) <- c("Estimate")
+            } else {
+              colnames(coefficients) <- unlist(labels)
+            }
+            rownames(coefficients) <- unlist(features)
+            list(coefficients = coefficients, intercept = intercept,
+                 numClasses = numClasses, numFeatures = numFeatures)
+          })
+
+# Save fitted LinearSVCModel to the input path
+
+#' @param path The directory where the model is saved.
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname spark.svmLinear
+#' @aliases write.ml,LinearSVCModel,character-method
+#' @export
+#' @note write.ml(LinearSVCModel, character) since 2.2.0
+setMethod("write.ml", signature(object = "LinearSVCModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' Logistic Regression Model
+#'
+#' Fits a logistic regression model against a SparkDataFrame. It supports "binomial": Binary logistic regression
+#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet.
+#' Users can print, make predictions on the produced model and save the model to the input path.
+#'
+#' @param data SparkDataFrame for training.
+#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' @param regParam the regularization parameter.
+#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty.
+#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination
+#' of L1 and L2. Default is 0.0 which is an L2 penalty.
+#' @param maxIter maximum iteration number.
+#' @param tol convergence tolerance of iterations.
+#' @param family the name of the family, which is a description of the label distribution to be used in the model.
+#' Supported options:
+#' \itemize{
+#' \item{"auto": Automatically select the family based on the number of classes:
+#' If number of classes == 1 || number of classes == 2, set to "binomial".
+#' Else, set to "multinomial".}
+#' \item{"binomial": Binary logistic regression with pivoting.}
+#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.}
+#' }
+#' @param standardization whether to standardize the training features before fitting the model. The coefficients
+#' of models will always be returned on the original scale, so it will be transparent for
+#' users. Note that with or without standardization, the models should always converge
+#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet.
+#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1
+#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0
+#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with
+#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification, this
+#' is used to adjust the probability of predicting each class. Array must have length equal to the number of
+#' classes, with values > 0, excepting that at most one value may be 0. The class with largest value p/t is
+#' predicted, where p is the original probability of that class and t is the class's threshold.
+#' @param weightCol The weight column name.
+#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
+#' or the number of partitions are large, this param could be adjusted to a larger size.
+#' This is an expert parameter. Default value should be good for most cases.
+#' @param ... additional arguments passed to the method.
+#' @return \code{spark.logit} returns a fitted logistic regression model.
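# A brief sketch of the penalty and family options described above ("training"
# as in the examples below):
en <- spark.logit(training, Survived ~ ., regParam = 0.3,
                  elasticNetParam = 0.5)              # elastic-net mix of L1 and L2
mn <- spark.logit(training, Class ~ ., family = "multinomial")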
+#' @rdname spark.logit
+#' @aliases spark.logit,SparkDataFrame,formula-method
+#' @name spark.logit
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' # binary logistic regression
+#' t <- as.data.frame(Titanic)
+#' training <- createDataFrame(t)
+#' model <- spark.logit(training, Survived ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, training)
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and predict
+#' # Note that summary does not work on a loaded model
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#'
+#' # multinomial logistic regression
+#'
+#' model <- spark.logit(training, Class ~ ., regParam = 0.5)
+#' summary <- summary(model)
+#'
+#' }
+#' @note spark.logit since 2.1.0
+setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
+                   tol = 1E-6, family = "auto", standardization = TRUE,
+                   thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
+            formula <- paste(deparse(formula), collapse = "")
+
+            if (!is.null(weightCol) && weightCol == "") {
+              weightCol <- NULL
+            } else if (!is.null(weightCol)) {
+              weightCol <- as.character(weightCol)
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
+                                data@sdf, formula, as.numeric(regParam),
+                                as.numeric(elasticNetParam), as.integer(maxIter),
+                                as.numeric(tol), as.character(family),
+                                as.logical(standardization), as.array(thresholds),
+                                weightCol, as.integer(aggregationDepth))
+            new("LogisticRegressionModel", jobj = jobj)
+          })
+
+# Get the summary of a LogisticRegressionModel
+
+#' @param object a LogisticRegressionModel fitted by \code{spark.logit}.
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' The list includes \code{coefficients} (coefficients matrix of the fitted model).
+#' @rdname spark.logit
+#' @aliases summary,LogisticRegressionModel-method
+#' @export
+#' @note summary(LogisticRegressionModel) since 2.1.0
+setMethod("summary", signature(object = "LogisticRegressionModel"),
+          function(object) {
+            jobj <- object@jobj
+            features <- callJMethod(jobj, "rFeatures")
+            labels <- callJMethod(jobj, "labels")
+            coefficients <- callJMethod(jobj, "rCoefficients")
+            nCol <- length(coefficients) / length(features)
+            coefficients <- matrix(unlist(coefficients), ncol = nCol)
+            # If nCol == 1, this is a binomial logistic regression model with pivoting.
+            # Otherwise, it's a multinomial logistic regression model without pivoting.
+            if (nCol == 1) {
+              colnames(coefficients) <- c("Estimate")
+            } else {
+              colnames(coefficients) <- unlist(labels)
+            }
+            rownames(coefficients) <- unlist(features)
+
+            list(coefficients = coefficients)
+          })
+
+# Predicted values based on a LogisticRegressionModel
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns the predicted values based on a LogisticRegressionModel.
+#' @rdname spark.logit
+#' @aliases predict,LogisticRegressionModel,SparkDataFrame-method
+#' @export
+#' @note predict(LogisticRegressionModel) since 2.1.0
+setMethod("predict", signature(object = "LogisticRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save fitted LogisticRegressionModel to the input path
+
+#' @param path The directory where the model is saved.
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.logit +#' @aliases write.ml,LogisticRegressionModel,character-method +#' @export +#' @note write.ml(LogisticRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Multilayer Perceptron Classification Model +#' +#' \code{spark.mlp} fits a multi-layer perceptron neural network model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' Only categorical data is supported. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html}{ +#' Multilayer Perceptron} +#' +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param blockSize blockSize parameter. +#' @param layers integer vector containing the number of nodes for each layer. +#' @param solver solver parameter, supported options: "gd" (minibatch gradient descent) or "l-bfgs". +#' @param maxIter maximum iteration number. +#' @param tol convergence tolerance of iterations. +#' @param stepSize stepSize parameter. +#' @param seed seed parameter for weights initialization. +#' @param initialWeights initialWeights parameter for weights initialization, it should be a +#' numeric vector. +#' @param ... additional arguments passed to the method. +#' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. 
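# A minimal sketch of the layers argument described above: sizes run from the
# input layer to the output layer, so 4 features, one hidden layer of 5 nodes
# and 3 classes would be c(4, 5, 3) (df as in the example below).
model <- spark.mlp(df, label ~ features, layers = c(4, 5, 3), maxIter = 100)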
+#' @rdname spark.mlp
+#' @aliases spark.mlp,SparkDataFrame,formula-method
+#' @name spark.mlp
+#' @seealso \link{read.ml}
+#' @export
+#' @examples
+#' \dontrun{
+#' df <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm")
+#'
+#' # fit a Multilayer Perceptron Classification Model
+#' model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 3), solver = "l-bfgs",
+#' maxIter = 100, tol = 0.5, stepSize = 1, seed = 1,
+#' initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9))
+#'
+#' # get the summary of the model
+#' summary(model)
+#'
+#' # make predictions
+#' predictions <- predict(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.mlp since 2.1.0
+setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100,
+                   tol = 1E-6, stepSize = 0.03, seed = NULL, initialWeights = NULL) {
+            formula <- paste(deparse(formula), collapse = "")
+            if (is.null(layers)) {
+              stop("layers must be an integer vector with length > 1.")
+            }
+            layers <- as.integer(na.omit(layers))
+            if (length(layers) <= 1) {
+              stop("layers must be an integer vector with length > 1.")
+            }
+            if (!is.null(seed)) {
+              seed <- as.character(as.integer(seed))
+            }
+            if (!is.null(initialWeights)) {
+              initialWeights <- as.array(as.numeric(na.omit(initialWeights)))
+            }
+            jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper",
+                                "fit", data@sdf, formula, as.integer(blockSize), as.array(layers),
+                                as.character(solver), as.integer(maxIter), as.numeric(tol),
+                                as.numeric(stepSize), seed, initialWeights)
+            new("MultilayerPerceptronClassificationModel", jobj = jobj)
+          })
+
+# Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp}
+
+#' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp}
+#' @return \code{summary} returns summary information of the fitted model, which is a list.
+#' The list includes \code{numOfInputs} (number of inputs), \code{numOfOutputs}
+#' (number of outputs), \code{layers} (array of layer sizes including input
+#' and output layers), and \code{weights} (the weights of layers).
+#' For \code{weights}, it is a numeric vector with length equal to the expected
+#' number of connections given the architecture (i.e., for an 8-10-2 network,
+#' 112 connection weights).
+#' @rdname spark.mlp
+#' @export
+#' @aliases summary,MultilayerPerceptronClassificationModel-method
+#' @note summary(MultilayerPerceptronClassificationModel) since 2.1.0
+setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel"),
+          function(object) {
+            jobj <- object@jobj
+            layers <- unlist(callJMethod(jobj, "layers"))
+            numOfInputs <- head(layers, n = 1)
+            numOfOutputs <- tail(layers, n = 1)
+            weights <- callJMethod(jobj, "weights")
+            list(numOfInputs = numOfInputs, numOfOutputs = numOfOutputs,
+                 layers = layers, weights = weights)
+          })
+
+# Makes predictions from a model produced by spark.mlp().
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
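# Checking the connection-weight count quoted in the summary documentation above:
# each layer contributes (size + 1 bias) * next_size weights, so an 8-10-2
# network carries 9 * 10 + 11 * 2 = 112 weights.
layers <- c(8, 10, 2)
sum((head(layers, -1) + 1) * tail(layers, -1))  # 112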
+#' @rdname spark.mlp +#' @aliases predict,MultilayerPerceptronClassificationModel-method +#' @export +#' @note predict(MultilayerPerceptronClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Multilayer Perceptron Classification Model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.mlp +#' @aliases write.ml,MultilayerPerceptronClassificationModel,character-method +#' @export +#' @seealso \link{write.ml} +#' @note write.ml(MultilayerPerceptronClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationModel", + path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Naive Bayes Models +#' +#' \code{spark.naiveBayes} fits a Bernoulli naive Bayes model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' Only categorical data is supported. +#' +#' @param data a \code{SparkDataFrame} of observations and labels for model fitting. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param smoothing smoothing parameter. +#' @param ... additional argument(s) passed to the method. Currently only \code{smoothing}. +#' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. +#' @rdname spark.naiveBayes +#' @aliases spark.naiveBayes,SparkDataFrame,formula-method +#' @name spark.naiveBayes +#' @seealso e1071: \url{https://cran.r-project.org/package=e1071} +#' @export +#' @examples +#' \dontrun{ +#' data <- as.data.frame(UCBAdmissions) +#' df <- createDataFrame(data) +#' +#' # fit a Bernoulli naive Bayes model +#' model <- spark.naiveBayes(df, Admit ~ Gender + Dept, smoothing = 0) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.naiveBayes since 2.0.0 +setMethod("spark.naiveBayes", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, smoothing = 1.0) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.NaiveBayesWrapper", "fit", + formula, data@sdf, smoothing) + new("NaiveBayesModel", jobj = jobj) + }) + +# Returns the summary of a naive Bayes model produced by \code{spark.naiveBayes} + +#' @param object a naive Bayes model fitted by \code{spark.naiveBayes}. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{apriori} (the label distribution) and +#' \code{tables} (conditional probabilities given the target label). 
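# A short sketch of the summary fields documented above (model assumed to come
# from spark.naiveBayes as in its example):
s <- summary(model)
s$apriori  # 1 x #labels matrix of label priors
s$tables   # #labels x #features matrix of conditional probabilities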
+#' @rdname spark.naiveBayes +#' @export +#' @note summary(NaiveBayesModel) since 2.0.0 +setMethod("summary", signature(object = "NaiveBayesModel"), + function(object) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + labels <- callJMethod(jobj, "labels") + apriori <- callJMethod(jobj, "apriori") + apriori <- t(as.matrix(unlist(apriori))) + colnames(apriori) <- unlist(labels) + tables <- callJMethod(jobj, "tables") + tables <- matrix(tables, nrow = length(labels)) + rownames(tables) <- unlist(labels) + colnames(tables) <- unlist(features) + list(apriori = apriori, tables = tables) + }) + +# Makes predictions from a naive Bayes model or a model produced by spark.naiveBayes(), +# similarly to R package e1071's predict. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.naiveBayes +#' @export +#' @note predict(NaiveBayesModel) since 2.0.0 +setMethod("predict", signature(object = "NaiveBayesModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Bernoulli naive Bayes model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.naiveBayes +#' @export +#' @seealso \link{write.ml} +#' @note write.ml(NaiveBayesModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R new file mode 100644 index 0000000000000..97c9fa1b45840 --- /dev/null +++ b/R/pkg/R/mllib_clustering.R @@ -0,0 +1,634 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# mllib_clustering.R: Provides methods for MLlib clustering algorithms integration + +#' S4 class that represents a BisectingKMeansModel +#' +#' @param jobj a Java object reference to the backing Scala BisectingKMeansModel +#' @export +#' @note BisectingKMeansModel since 2.2.0 +setClass("BisectingKMeansModel", representation(jobj = "jobj")) + +#' S4 class that represents a GaussianMixtureModel +#' +#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel +#' @export +#' @note GaussianMixtureModel since 2.1.0 +setClass("GaussianMixtureModel", representation(jobj = "jobj")) + +#' S4 class that represents a KMeansModel +#' +#' @param jobj a Java object reference to the backing Scala KMeansModel +#' @export +#' @note KMeansModel since 2.0.0 +setClass("KMeansModel", representation(jobj = "jobj")) + +#' S4 class that represents an LDAModel +#' +#' @param jobj a Java object reference to the backing Scala LDAWrapper +#' @export +#' @note LDAModel since 2.1.0 +setClass("LDAModel", representation(jobj = "jobj")) + +#' Bisecting K-Means Clustering Model +#' +#' Fits a bisecting k-means clustering model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' Note that the response variable of formula is empty in spark.bisectingKmeans. +#' @param k the desired number of leaf clusters. Must be > 1. +#' The actual number could be smaller if there are no divisible leaf clusters. +#' @param maxIter maximum iteration number. +#' @param seed the random seed. +#' @param minDivisibleClusterSize The minimum number of points (if greater than or equal to 1.0) +#' or the minimum proportion of points (if less than 1.0) of a divisible cluster. +#' Note that it is an expert parameter. The default value should be good enough +#' for most cases. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.bisectingKmeans} returns a fitted bisecting k-means model. 
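# A minimal sketch of minDivisibleClusterSize as documented above: values of at
# least 1 are read as an absolute point count, values below 1 as a proportion of
# the points (df as in the examples below).
m1 <- spark.bisectingKmeans(df, Class ~ Survived, k = 4, minDivisibleClusterSize = 10)
m2 <- spark.bisectingKmeans(df, Class ~ Survived, k = 4, minDivisibleClusterSize = 0.05)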
+#' @rdname spark.bisectingKmeans +#' @aliases spark.bisectingKmeans,SparkDataFrame,formula-method +#' @name spark.bisectingKmeans +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.bisectingKmeans(df, Class ~ Survived, k = 4) +#' summary(model) +#' +#' # get fitted result from a bisecting k-means model +#' fitted.model <- fitted(model, "centers") +#' showDF(fitted.model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Class", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.bisectingKmeans since 2.2.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.bisectingKmeans", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 4, maxIter = 20, seed = NULL, minDivisibleClusterSize = 1.0) { + formula <- paste0(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + jobj <- callJStatic("org.apache.spark.ml.r.BisectingKMeansWrapper", "fit", + data@sdf, formula, as.integer(k), as.integer(maxIter), + seed, as.numeric(minDivisibleClusterSize)) + new("BisectingKMeansModel", jobj = jobj) + }) + +# Get the summary of a bisecting k-means model + +#' @param object a fitted bisecting k-means model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), \code{cluster} +#' (cluster centers of the transformed data; cluster is NULL if is.loaded is TRUE), +#' and \code{is.loaded} (whether the model is loaded from a saved file). +#' @rdname spark.bisectingKmeans +#' @export +#' @note summary(BisectingKMeansModel) since 2.2.0 +setMethod("summary", signature(object = "BisectingKMeansModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + features <- callJMethod(jobj, "features") + coefficients <- callJMethod(jobj, "coefficients") + k <- callJMethod(jobj, "k") + size <- callJMethod(jobj, "size") + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + cluster <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "cluster")) + } + list(k = k, coefficients = coefficients, size = size, + cluster = cluster, is.loaded = is.loaded) + }) + +# Predicted values based on a bisecting k-means model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on a bisecting k-means model. +#' @rdname spark.bisectingKmeans +#' @export +#' @note predict(BisectingKMeansModel) since 2.2.0 +setMethod("predict", signature(object = "BisectingKMeansModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' Get fitted result from a bisecting k-means model +#' +#' Get fitted result from a bisecting k-means model. +#' Note: A saved-loaded model does not support this method. +#' +#' @param method type of fitted results, \code{"centers"} for cluster centers +#' or \code{"classes"} for assigned classes. +#' @return \code{fitted} returns a SparkDataFrame containing fitted values. 
+#' @rdname spark.bisectingKmeans +#' @export +#' @note fitted since 2.2.0 +setMethod("fitted", signature(object = "BisectingKMeansModel"), + function(object, method = c("centers", "classes")) { + method <- match.arg(method) + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + if (is.loaded) { + stop("Saved-loaded bisecting k-means model does not support 'fitted' method") + } else { + dataFrame(callJMethod(jobj, "fitted", method)) + } + }) + +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.bisectingKmeans +#' @export +#' @note write.ml(BisectingKMeansModel, character) since 2.2.0 +setMethod("write.ml", signature(object = "BisectingKMeansModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Multivariate Gaussian Mixture Model (GMM) +#' +#' Fits multivariate gaussian mixture model against a SparkDataFrame, similarly to R's +#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} +#' to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' Note that the response variable of formula is empty in spark.gaussianMixture. +#' @param k number of independent Gaussians in the mixture model. +#' @param maxIter maximum iteration number. +#' @param tol the convergence tolerance. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method +#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model. +#' @rdname spark.gaussianMixture +#' @name spark.gaussianMixture +#' @seealso mixtools: \url{https://cran.r-project.org/package=mixtools} +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' library(mvtnorm) +#' set.seed(100) +#' a <- rmvnorm(4, c(0, 0)) +#' b <- rmvnorm(6, c(3, 4)) +#' data <- rbind(a, b) +#' df <- createDataFrame(as.data.frame(data)) +#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2) +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "V1", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.gaussianMixture since 2.1.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 2, maxIter = 100, tol = 0.01) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf, + formula, as.integer(k), as.integer(maxIter), as.numeric(tol)) + new("GaussianMixtureModel", jobj = jobj) + }) + +# Get the summary of a multivariate gaussian mixture model + +#' @param object a fitted gaussian mixture model. +#' @return \code{summary} returns summary of the fitted model, which is a list. 
+#' The list includes the model's \code{lambda} (the mixing weights of the components),
+#' \code{mu} (the component mean vectors), \code{sigma} (the component covariance matrices),
+#' \code{loglik} (the log-likelihood of the fit), \code{posterior} (the posterior membership
+#' probabilities of the training data; NULL if the model is loaded from a saved file),
+#' and \code{is.loaded} (whether the model is loaded from a saved file).
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note summary(GaussianMixtureModel) since 2.1.0
+setMethod("summary", signature(object = "GaussianMixtureModel"),
+          function(object) {
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            lambda <- unlist(callJMethod(jobj, "lambda"))
+            muList <- callJMethod(jobj, "mu")
+            sigmaList <- callJMethod(jobj, "sigma")
+            k <- callJMethod(jobj, "k")
+            dim <- callJMethod(jobj, "dim")
+            loglik <- callJMethod(jobj, "logLikelihood")
+            # The JVM side returns the means and covariances as flat lists; regroup them
+            # into one mean vector of length dim per component.
+            mu <- c()
+            for (i in 1 : k) {
+              start <- (i - 1) * dim + 1
+              end <- i * dim
+              mu[[i]] <- unlist(muList[start : end])
+            }
+            # ... and one dim x dim covariance matrix per component.
+            sigma <- c()
+            for (i in 1 : k) {
+              start <- (i - 1) * dim * dim + 1
+              end <- i * dim * dim
+              sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim))
+            }
+            posterior <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "posterior"))
+            }
+            list(lambda = lambda, mu = mu, sigma = sigma, loglik = loglik,
+                 posterior = posterior, is.loaded = is.loaded)
+          })
+
+# Predicted values based on a gaussian mixture model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#' "prediction".
+#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note predict(GaussianMixtureModel) since 2.1.0
+setMethod("predict", signature(object = "GaussianMixtureModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Save fitted MLlib model to the input path
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,GaussianMixtureModel,character-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
+
+#' K-Means Clustering Model
+#'
+#' Fits a k-means clustering model against a SparkDataFrame, similarly to R's kmeans().
+#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make
+#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#' operators are supported, including '~', '.', ':', '+', and '-'.
+#' Note that the response variable of formula is empty in spark.kmeans.
+#' @param k number of centers.
+#' @param maxIter maximum iteration number.
+#' @param initMode the initialization algorithm chosen to fit the model.
+#' @param seed the random seed for cluster initialization.
+#' @param initSteps the number of steps for the k-means|| initialization mode.
+#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0.
+#' @param tol convergence tolerance of iterations.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.kmeans} returns a fitted k-means model.
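+#' @details \code{initMode} accepts \code{"k-means||"} (the default, a parallel variant of the
+#' k-means++ seeding) or \code{"random"}. A minimal sketch, reusing \code{df} from the example
+#' below: \code{spark.kmeans(df, Class ~ Survived, k = 4, initMode = "k-means||")}.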
+#' @rdname spark.kmeans +#' @aliases spark.kmeans,SparkDataFrame,formula-method +#' @name spark.kmeans +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.kmeans(df, Class ~ Survived, k = 4, initMode = "random") +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Class", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.kmeans since 2.0.0 +#' @seealso \link{predict}, \link{read.ml}, \link{write.ml} +setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"), + seed = NULL, initSteps = 2, tol = 1E-4) { + formula <- paste(deparse(formula), collapse = "") + initMode <- match.arg(initMode) + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula, + as.integer(k), as.integer(maxIter), initMode, seed, + as.integer(initSteps), as.numeric(tol)) + new("KMeansModel", jobj = jobj) + }) + +# Get the summary of a k-means model + +#' @param object a fitted k-means model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{k} (the configured number of cluster centers), +#' \code{coefficients} (model cluster centers), +#' \code{size} (number of data points in each cluster), \code{cluster} +#' (cluster centers of the transformed data), {is.loaded} (whether the model is loaded +#' from a saved file), and \code{clusterSize} +#' (the actual number of cluster centers. When using initMode = "random", +#' \code{clusterSize} may not equal to \code{k}). +#' @rdname spark.kmeans +#' @export +#' @note summary(KMeansModel) since 2.0.0 +setMethod("summary", signature(object = "KMeansModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + features <- callJMethod(jobj, "features") + coefficients <- callJMethod(jobj, "coefficients") + k <- callJMethod(jobj, "k") + size <- callJMethod(jobj, "size") + clusterSize <- callJMethod(jobj, "clusterSize") + coefficients <- t(matrix(unlist(coefficients), ncol = clusterSize)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:clusterSize + cluster <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "cluster")) + } + list(k = k, coefficients = coefficients, size = size, + cluster = cluster, is.loaded = is.loaded, clusterSize = clusterSize) + }) + +# Predicted values based on a k-means model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns the predicted values based on a k-means model. +#' @rdname spark.kmeans +#' @export +#' @note predict(KMeansModel) since 2.0.0 +setMethod("predict", signature(object = "KMeansModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' Get fitted result from a k-means model +#' +#' Get fitted result from a k-means model, similarly to R's fitted(). +#' Note: A saved-loaded model does not support this method. +#' +#' @param object a fitted k-means model. +#' @param method type of fitted results, \code{"centers"} for cluster centers +#' or \code{"classes"} for assigned classes. 
+#' @param ... additional argument(s) passed to the method. +#' @return \code{fitted} returns a SparkDataFrame containing fitted values. +#' @rdname fitted +#' @export +#' @examples +#' \dontrun{ +#' model <- spark.kmeans(trainingData, ~ ., 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +#' @note fitted since 2.0.0 +setMethod("fitted", signature(object = "KMeansModel"), + function(object, method = c("centers", "classes")) { + method <- match.arg(method) + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + if (is.loaded) { + stop("Saved-loaded k-means model does not support 'fitted' method") + } else { + dataFrame(callJMethod(jobj, "fitted", method)) + } + }) + +# Save fitted MLlib model to the input path + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.kmeans +#' @export +#' @note write.ml(KMeansModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "KMeansModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Latent Dirichlet Allocation +#' +#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call +#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute +#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new +#' data and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data A SparkDataFrame for training. +#' @param features Features column name. Either libSVM-format column or character-format column is +#' valid. +#' @param k Number of topics. +#' @param maxIter Maximum iterations. +#' @param optimizer Optimizer to train an LDA model, "online" or "em", default is "online". +#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in +#' each iteration of mini-batch gradient descent, in range (0, 1]. +#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for +#' the prior placed on topic distributions over terms, default -1 to set automatically on the +#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size +#' numeric is accepted. +#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the +#' prior placed on documents distributions over topics (\code{theta}), default -1 to set +#' automatically on the Spark side. Use \code{summary} to retrieve the effective +#' docConcentration. Only 1-size or \code{k}-size numeric is accepted. +#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the +#' parameter if libSVM-format column is used as the features column. +#' @param maxVocabSize maximum vocabulary size, default 1 << 18 +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model. 
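+#' @details The default \code{maxVocabSize} of \code{1 << 18} (i.e. \code{bitwShiftL(1, 18)})
+#' is 262144 terms. An illustrative call with the documented defaults spelled out:
+#' \code{spark.lda(text, k = 10, maxIter = 20, optimizer = "online", subsamplingRate = 0.05)},
+#' where \code{text} is the SparkDataFrame from the example below.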
+#' @rdname spark.lda +#' @aliases spark.lda,SparkDataFrame-method +#' @seealso topicmodels: \url{https://cran.r-project.org/package=topicmodels} +#' @export +#' @examples +#' \dontrun{ +#' text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm") +#' model <- spark.lda(data = text, optimizer = "em") +#' +#' # get a summary of the model +#' summary(model) +#' +#' # compute posterior probabilities +#' posterior <- spark.posterior(model, text) +#' showDF(posterior) +#' +#' # compute perplexity +#' perplexity <- spark.perplexity(model, text) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.lda since 2.1.0 +setMethod("spark.lda", signature(data = "SparkDataFrame"), + function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"), + subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1, + customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) { + optimizer <- match.arg(optimizer) + jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features, + as.integer(k), as.integer(maxIter), optimizer, + as.numeric(subsamplingRate), topicConcentration, + as.array(docConcentration), as.array(customizedStopWords), + maxVocabSize) + new("LDAModel", jobj = jobj) + }) + +# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}. +#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes +#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for +#' the prior placed on documents distributions over topics \code{theta}} +#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or +#' \code{eta} for the prior placed on topic distributions over terms} +#' \item{\code{logLikelihood}}{log likelihood of the entire corpus} +#' \item{\code{logPerplexity}}{log perplexity} +#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model} +#' \item{\code{vocabSize}}{number of terms in the corpus} +#' \item{\code{topics}}{top 10 terms and their weights of all topics} +#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file +#' used as training set} +#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set, +#' given the current parameter estimates: +#' log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) +#' It is only for distributed LDA model (i.e., optimizer = "em")} +#' \item{\code{logPrior}}{Log probability of the current parameter estimate: +#' log P(topics, topic distributions for docs | Dirichlet hyperparameters) +#' It is only for distributed LDA model (i.e., optimizer = "em")} +#' @rdname spark.lda +#' @aliases summary,LDAModel-method +#' @export +#' @note summary(LDAModel) since 2.1.0 +setMethod("summary", signature(object = "LDAModel"), + function(object, maxTermsPerTopic) { + maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic)) + jobj <- object@jobj + docConcentration <- callJMethod(jobj, "docConcentration") + topicConcentration <- callJMethod(jobj, "topicConcentration") + logLikelihood <- callJMethod(jobj, "logLikelihood") + 
logPerplexity <- callJMethod(jobj, "logPerplexity") + isDistributed <- callJMethod(jobj, "isDistributed") + vocabSize <- callJMethod(jobj, "vocabSize") + topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic)) + vocabulary <- callJMethod(jobj, "vocabulary") + trainingLogLikelihood <- if (isDistributed) { + callJMethod(jobj, "trainingLogLikelihood") + } else { + NA + } + logPrior <- if (isDistributed) { + callJMethod(jobj, "logPrior") + } else { + NA + } + list(docConcentration = unlist(docConcentration), + topicConcentration = topicConcentration, + logLikelihood = logLikelihood, logPerplexity = logPerplexity, + isDistributed = isDistributed, vocabSize = vocabSize, + topics = topics, vocabulary = unlist(vocabulary), + trainingLogLikelihood = trainingLogLikelihood, logPrior = logPrior) + }) + +# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda} + +#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log +#' perplexity of the training data if missing argument "data". +#' @rdname spark.lda +#' @aliases spark.perplexity,LDAModel-method +#' @export +#' @note spark.perplexity(LDAModel) since 2.1.0 +setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"), + function(object, data) { + ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"), + callJMethod(object@jobj, "computeLogPerplexity", data@sdf)) + }) + +# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda() + +#' @param newData A SparkDataFrame for testing. +#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities +#' vectors named "topicDistribution". +#' @rdname spark.lda +#' @aliases spark.posterior,LDAModel,SparkDataFrame-method +#' @export +#' @note spark.posterior(LDAModel) since 2.1.0 +setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the Latent Dirichlet Allocation model to the input path. + +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.lda +#' @aliases write.ml,LDAModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(LDAModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "LDAModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R new file mode 100644 index 0000000000000..dfcb45a1b66c9 --- /dev/null +++ b/R/pkg/R/mllib_fpm.R @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration + +#' S4 class that represents a FPGrowthModel +#' +#' @param jobj a Java object reference to the backing Scala FPGrowthModel +#' @export +#' @note FPGrowthModel since 2.2.0 +setClass("FPGrowthModel", slots = list(jobj = "jobj")) + +#' FP-growth +#' +#' A parallel FP-growth algorithm to mine frequent itemsets. +#' \code{spark.fpGrowth} fits a FP-growth model on a SparkDataFrame. Users can +#' \code{spark.freqItemsets} to get frequent itemsets, \code{spark.associationRules} to get +#' association rules, \code{predict} to make predictions on new data based on generated association +#' rules, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' For more details, see +#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{ +#' FP-growth}. +#' +#' @param data A SparkDataFrame for training. +#' @param minSupport Minimal support level. +#' @param minConfidence Minimal confidence level. +#' @param itemsCol Features column name. +#' @param numPartitions Number of partitions used for fitting. +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model. +#' @rdname spark.fpGrowth +#' @name spark.fpGrowth +#' @aliases spark.fpGrowth,SparkDataFrame-method +#' @export +#' @examples +#' \dontrun{ +#' raw_data <- read.df( +#' "data/mllib/sample_fpgrowth.txt", +#' source = "csv", +#' schema = structType(structField("raw_items", "string"))) +#' +#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items") +#' model <- spark.fpGrowth(data) +#' +#' # Show frequent itemsets +#' frequent_itemsets <- spark.freqItemsets(model) +#' showDF(frequent_itemsets) +#' +#' # Show association rules +#' association_rules <- spark.associationRules(model) +#' showDF(association_rules) +#' +#' # Predict on new data +#' new_itemsets <- data.frame(items = c("t", "t,s")) +#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items") +#' predict(model, new_data) +#' +#' # Save and load model +#' path <- "/path/to/model" +#' write.ml(model, path) +#' read.ml(path) +#' +#' # Optional arguments +#' baskets_data <- selectExpr(createDataFrame(itemsets), "split(items, ',') as baskets") +#' another_model <- spark.fpGrowth(data, minSupport = 0.1, minConfidence = 0.5, +#' itemsCol = "baskets", numPartitions = 10) +#' } +#' @note spark.fpGrowth since 2.2.0 +setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"), + function(data, minSupport = 0.3, minConfidence = 0.8, + itemsCol = "items", numPartitions = NULL) { + if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) { + stop("minSupport should be a number [0, 1].") + } + if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) { + stop("minConfidence should be a number [0, 1].") + } + if (!is.null(numPartitions)) { + numPartitions <- as.integer(numPartitions) + stopifnot(numPartitions > 0) + } + + jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit", + data@sdf, as.numeric(minSupport), as.numeric(minConfidence), + itemsCol, numPartitions) + new("FPGrowthModel", jobj = jobj) + }) + +# Get frequent itemsets. + +#' @param object a fitted FPGrowth model. +#' @return A \code{SparkDataFrame} with frequent itemsets. 
+#' The \code{SparkDataFrame} contains two columns:
+#' \code{items} (an array of the same type as the input column)
+#' and \code{freq} (frequency of the itemset).
+#' @rdname spark.fpGrowth
+#' @aliases freqItemsets,FPGrowthModel-method
+#' @export
+#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0
+setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
+          function(object) {
+            dataFrame(callJMethod(object@jobj, "freqItemsets"))
+          })
+
+# Get association rules.
+
+#' @return A \code{SparkDataFrame} with association rules.
+#' The \code{SparkDataFrame} contains three columns:
+#' \code{antecedent} (an array of the same type as the input column),
+#' \code{consequent} (an array of the same type as the input column),
+#' and \code{confidence} (confidence of the rule).
+#' @rdname spark.fpGrowth
+#' @aliases associationRules,FPGrowthModel-method
+#' @export
+#' @note spark.associationRules(FPGrowthModel) since 2.2.0
+setMethod("spark.associationRules", signature(object = "FPGrowthModel"),
+          function(object) {
+            dataFrame(callJMethod(object@jobj, "associationRules"))
+          })
+
+# Makes predictions based on generated association rules
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.fpGrowth
+#' @aliases predict,FPGrowthModel-method
+#' @export
+#' @note predict(FPGrowthModel) since 2.2.0
+setMethod("predict", signature(object = "FPGrowthModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Saves the FPGrowth model to the output path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#' already exists. Default is FALSE which means throw exception
+#' if the output path exists.
+#' @rdname spark.fpGrowth
+#' @aliases write.ml,FPGrowthModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(FPGrowthModel, character) since 2.2.0
+setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
diff --git a/R/pkg/R/mllib_recommendation.R b/R/pkg/R/mllib_recommendation.R
new file mode 100644
index 0000000000000..fa794249085d7
--- /dev/null
+++ b/R/pkg/R/mllib_recommendation.R
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +# mllib_recommendation.R: Provides methods for MLlib recommendation algorithms integration + +#' S4 class that represents an ALSModel +#' +#' @param jobj a Java object reference to the backing Scala ALSWrapper +#' @export +#' @note ALSModel since 2.1.0 +setClass("ALSModel", representation(jobj = "jobj")) + +#' Alternating Least Squares (ALS) for Collaborative Filtering +#' +#' \code{spark.als} learns latent factors in collaborative filtering via alternating least +#' squares. Users can call \code{summary} to obtain fitted latent factors, \code{predict} +#' to make predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-collaborative-filtering.html}{MLlib: +#' Collaborative Filtering}. +#' +#' @param data a SparkDataFrame for training. +#' @param ratingCol column name for ratings. +#' @param userCol column name for user ids. Ids must be (or can be coerced into) integers. +#' @param itemCol column name for item ids. Ids must be (or can be coerced into) integers. +#' @param rank rank of the matrix factorization (> 0). +#' @param regParam regularization parameter (>= 0). +#' @param maxIter maximum number of iterations (>= 0). +#' @param nonnegative logical value indicating whether to apply nonnegativity constraints. +#' @param implicitPrefs logical value indicating whether to use implicit preference. +#' @param alpha alpha parameter in the implicit preference formulation (>= 0). +#' @param seed integer seed for random number generation. +#' @param numUserBlocks number of user blocks used to parallelize computation (> 0). +#' @param numItemBlocks number of item blocks used to parallelize computation (> 0). +#' @param checkpointInterval number of checkpoint intervals (>= 1) or disable checkpoint (-1). +#' @param ... additional argument(s) passed to the method. +#' @return \code{spark.als} returns a fitted ALS model. 
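+#' @details For implicit-feedback data, an illustrative call (parameter values are only an
+#' example) is \code{spark.als(df, "rating", "user", "item", implicitPrefs = TRUE, alpha = 1.0)};
+#' \code{alpha} scales the confidence placed in the observed implicit preferences.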
+#' @rdname spark.als +#' @aliases spark.als,SparkDataFrame-method +#' @name spark.als +#' @export +#' @examples +#' \dontrun{ +#' ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), +#' list(2, 1, 1.0), list(2, 2, 5.0)) +#' df <- createDataFrame(ratings, c("user", "item", "rating")) +#' model <- spark.als(df, "rating", "user", "item") +#' +#' # extract latent factors +#' stats <- summary(model) +#' userFactors <- stats$userFactors +#' itemFactors <- stats$itemFactors +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # set other arguments +#' modelS <- spark.als(df, "rating", "user", "item", rank = 20, +#' regParam = 0.1, nonnegative = TRUE) +#' statsS <- summary(modelS) +#' } +#' @note spark.als since 2.1.0 +setMethod("spark.als", signature(data = "SparkDataFrame"), + function(data, ratingCol = "rating", userCol = "user", itemCol = "item", + rank = 10, regParam = 0.1, maxIter = 10, nonnegative = FALSE, + implicitPrefs = FALSE, alpha = 1.0, numUserBlocks = 10, numItemBlocks = 10, + checkpointInterval = 10, seed = 0) { + + if (!is.numeric(rank) || rank <= 0) { + stop("rank should be a positive number.") + } + if (!is.numeric(regParam) || regParam < 0) { + stop("regParam should be a nonnegative number.") + } + if (!is.numeric(maxIter) || maxIter <= 0) { + stop("maxIter should be a positive number.") + } + + jobj <- callJStatic("org.apache.spark.ml.r.ALSWrapper", + "fit", data@sdf, ratingCol, userCol, itemCol, as.integer(rank), + regParam, as.integer(maxIter), implicitPrefs, alpha, nonnegative, + as.integer(numUserBlocks), as.integer(numItemBlocks), + as.integer(checkpointInterval), as.integer(seed)) + new("ALSModel", jobj = jobj) + }) + +# Returns a summary of the ALS model produced by spark.als. + +#' @param object a fitted ALS model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes \code{user} (the names of the user column), +#' \code{item} (the item column), \code{rating} (the rating column), \code{userFactors} +#' (the estimated user factors), \code{itemFactors} (the estimated item factors), +#' and \code{rank} (rank of the matrix factorization model). +#' @rdname spark.als +#' @aliases summary,ALSModel-method +#' @export +#' @note summary(ALSModel) since 2.1.0 +setMethod("summary", signature(object = "ALSModel"), + function(object) { + jobj <- object@jobj + user <- callJMethod(jobj, "userCol") + item <- callJMethod(jobj, "itemCol") + rating <- callJMethod(jobj, "ratingCol") + userFactors <- dataFrame(callJMethod(jobj, "userFactors")) + itemFactors <- dataFrame(callJMethod(jobj, "itemFactors")) + rank <- callJMethod(jobj, "rank") + list(user = user, item = item, rating = rating, userFactors = userFactors, + itemFactors = itemFactors, rank = rank) + }) + +# Makes predictions from an ALS model or a model produced by spark.als. + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.als +#' @aliases predict,ALSModel-method +#' @export +#' @note predict(ALSModel) since 2.1.0 +setMethod("predict", signature(object = "ALSModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the ALS model to the input path. + +#' @param path the directory where the model is saved. 
+#' @param overwrite logical value indicating whether to overwrite if the output path +#' already exists. Default is FALSE which means throw exception +#' if the output path exists. +#' +#' @rdname spark.als +#' @aliases write.ml,ALSModel,character-method +#' @export +#' @seealso \link{read.ml} +#' @note write.ml(ALSModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "ALSModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R new file mode 100644 index 0000000000000..d59c890f3e5fd --- /dev/null +++ b/R/pkg/R/mllib_regression.R @@ -0,0 +1,500 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_regression.R: Provides methods for MLlib regression algorithms +# (except for tree-based algorithms) integration + +#' S4 class that represents a AFTSurvivalRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper +#' @export +#' @note AFTSurvivalRegressionModel since 2.0.0 +setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a generalized linear model +#' +#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper +#' @export +#' @note GeneralizedLinearRegressionModel since 2.0.0 +setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents an IsotonicRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala IsotonicRegressionModel +#' @export +#' @note IsotonicRegressionModel since 2.1.0 +setClass("IsotonicRegressionModel", representation(jobj = "jobj")) + +#' Generalized Linear Models +#' +#' Fits generalized linear model against a SparkDataFrame. +#' Users can call \code{summary} to print a summary of the fitted model, \code{predict} to make +#' predictions on new data, and \code{write.ml}/\code{read.ml} to save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param family a description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, \code{poisson} and \code{tweedie}. +#' +#' Note that there are two ways to specify the tweedie family. 
+#' \itemize{ +#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; +#' \item When package \code{statmod} is loaded, the tweedie family is specified using the +#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' } +#' @param tol positive convergence tolerance of iterations. +#' @param maxIter integer giving the maximal number of IRLS iterations. +#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance +#' weights as 1.0. +#' @param regParam regularization parameter for L2 regularization. +#' @param var.power the power in the variance function of the Tweedie distribution which provides +#' the relationship between the variance and mean of the distribution. Only +#' applicable to the Tweedie family. +#' @param link.power the index in the power link function. Only applicable to the Tweedie family. +#' @param ... additional arguments passed to the method. +#' @aliases spark.glm,SparkDataFrame,formula-method +#' @return \code{spark.glm} returns a fitted generalized linear model. +#' @rdname spark.glm +#' @name spark.glm +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") +#' summary(model) +#' +#' # fitted values on training data +#' fitted <- predict(model, df) +#' head(select(fitted, "Freq", "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit tweedie model +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", +#' var.power = 1.2, link.power = 0) +#' summary(model) +#' +#' # use the tweedie family from statmod +#' library(statmod) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) +#' summary(model) +#' } +#' @note spark.glm since 2.0.0 +#' @seealso \link{glm}, \link{read.ml} +setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, + regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { + + if (is.character(family)) { + # Handle when family = "tweedie" + if (tolower(family) == "tweedie") { + family <- list(family = "tweedie", link = NULL) + } else { + family <- get(family, mode = "function", envir = parent.frame()) + } + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + # Handle when family = statmod::tweedie() + if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { + var.power <- log(family$variance(exp(1))) + link.power <- log(family$linkfun(exp(1))) + family <- list(family = "tweedie", link = NULL) + } + + formula <- paste(deparse(formula), collapse = "") + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) + } + + # For known families, Gamma is upper-cased + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, tolower(family$family), family$link, + tol, as.integer(maxIter), weightCol, regParam, + as.double(var.power), as.double(link.power)) + new("GeneralizedLinearRegressionModel", jobj = jobj) + }) + +#' Generalized Linear Models (R-compliant) +#' +#' Fits a generalized 
linear model, similarly to R's glm(). +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param data a SparkDataFrame or R's glm data for training. +#' @param family a description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{poisson}, \code{Gamma}, and \code{tweedie}. +#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance +#' weights as 1.0. +#' @param epsilon positive convergence tolerance of iterations. +#' @param maxit integer giving the maximal number of IRLS iterations. +#' @param var.power the index of the power variance function in the Tweedie family. +#' @param link.power the index of the power link function in the Tweedie family. +#' @return \code{glm} returns a fitted generalized linear model. +#' @rdname glm +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian") +#' summary(model) +#' } +#' @note glm since 1.5.0 +#' @seealso \link{spark.glm} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), + function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, + var.power = 0.0, link.power = 1.0 - var.power) { + spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, + var.power = var.power, link.power = link.power) + }) + +# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). + +#' @param object a fitted generalized linear model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes +#' coefficients, standard error of coefficients, t value and p value), +#' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, +#' the coefficients matrix only provides coefficients. 
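+#' A minimal sketch of inspecting these components (assuming \code{model} was fitted as in the
+#' example above):
+#' \preformatted{
+#'   s <- summary(model)
+#'   s$coefficients   # estimates, plus Std. Error, t value and Pr(>|t|) when available
+#'   s$aic
+#'   s$iter
+#' }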
+#' @rdname spark.glm +#' @export +#' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 +setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), + function(object) { + jobj <- object@jobj + is.loaded <- callJMethod(jobj, "isLoaded") + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + dispersion <- callJMethod(jobj, "rDispersion") + null.deviance <- callJMethod(jobj, "rNullDeviance") + deviance <- callJMethod(jobj, "rDeviance") + df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") + df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") + iter <- callJMethod(jobj, "rNumIterations") + family <- callJMethod(jobj, "rFamily") + aic <- callJMethod(jobj, "rAic") + if (family == "tweedie" && aic == 0) aic <- NA + deviance.resid <- if (is.loaded) { + NULL + } else { + dataFrame(callJMethod(jobj, "rDevianceResiduals")) + } + # If the underlying WeightedLeastSquares using "normal" solver, we can provide + # coefficients, standard error of coefficients, t value and p value. Otherwise, + # it will be fitted by local "l-bfgs", we can only provide coefficients. + if (length(features) == length(coefficients)) { + coefficients <- matrix(unlist(coefficients), ncol = 1) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + } else { + coefficients <- matrix(unlist(coefficients), ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + } + ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, + dispersion = dispersion, null.deviance = null.deviance, + deviance = deviance, df.null = df.null, df.residual = df.residual, + aic = aic, iter = iter, family = family, is.loaded = is.loaded) + class(ans) <- "summary.GeneralizedLinearRegressionModel" + ans + }) + +# Prints the summary of GeneralizedLinearRegressionModel + +#' @rdname spark.glm +#' @param x summary object of fitted generalized linear model returned by \code{summary} function. +#' @export +#' @note print.summary.GeneralizedLinearRegressionModel since 2.0.0 +print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { + if (x$is.loaded) { + cat("\nSaved-loaded model does not support output 'Deviance Residuals'.\n") + } else { + x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals", + c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max")) + x$deviance.resid <- zapsmall(x$deviance.resid, 5L) + cat("\nDeviance Residuals: \n") + cat("(Note: These are approximate quantiles with relative error <= 0.01)\n") + print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L) + } + + cat("\nCoefficients:\n") + print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L) + + cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion), + ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"), + format(unlist(x[c("null.deviance", "deviance")]), digits = 5L), + " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"), + 1L, paste, collapse = " "), sep = "") + cat("AIC: ", format(x$aic, digits = 4L), "\n\n", + "Number of Fisher Scoring iterations: ", x$iter, "\n\n", sep = "") + invisible(x) + } + +# Makes predictions from a generalized linear model produced by glm() or spark.glm(), +# similarly to R's predict(). + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named +#' "prediction". +#' @rdname spark.glm +#' @export +#' @note predict(GeneralizedLinearRegressionModel) since 1.5.0 +setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Saves the generalized linear model to the input path. + +#' @param path the directory where the model is saved. +#' @param overwrite overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @rdname spark.glm +#' @export +#' @note write.ml(GeneralizedLinearRegressionModel, character) since 2.0.0 +setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Isotonic Regression Model +#' +#' Fits an Isotonic Regression model against a SparkDataFrame, similarly to R's isoreg(). +#' Users can print, make predictions on the produced model and save the model to the input path. +#' +#' @param data SparkDataFrame for training. +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', '.', ':', '+', and '-'. +#' @param isotonic Whether the output sequence should be isotonic/increasing (TRUE) or +#' antitonic/decreasing (FALSE). +#' @param featureIndex The index of the feature if \code{featuresCol} is a vector column +#' (default: 0), no effect otherwise. +#' @param weightCol The weight column name. +#' @param ... additional arguments passed to the method. +#' @return \code{spark.isoreg} returns a fitted Isotonic Regression model. 
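+#' @details Given its default of 0, \code{featureIndex} is interpreted as a zero-based position
+#' within the feature vector; it has no effect unless the features column is a vector column.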
+#' @rdname spark.isoreg +#' @aliases spark.isoreg,SparkDataFrame,formula-method +#' @name spark.isoreg +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' data <- list(list(7.0, 0.0), list(5.0, 1.0), list(3.0, 2.0), +#' list(5.0, 3.0), list(1.0, 4.0)) +#' df <- createDataFrame(data, c("label", "feature")) +#' model <- spark.isoreg(df, label ~ feature, isotonic = FALSE) +#' # return model boundaries and prediction as lists +#' result <- summary(model, df) +#' # prediction based on fitted model +#' predict_data <- list(list(-2.0), list(-1.0), list(0.5), +#' list(0.75), list(1.0), list(2.0), list(9.0)) +#' predict_df <- createDataFrame(predict_data, c("feature")) +#' # get prediction column +#' predict_result <- collect(select(predict(model, predict_df), "prediction")) +#' +#' # save fitted model to input path +#' path <- "path/to/model" +#' write.ml(model, path) +#' +#' # can also read back the saved model and print +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.isoreg since 2.1.0 +setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { + formula <- paste(deparse(formula), collapse = "") + + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) + } + + jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", + data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), + weightCol) + new("IsotonicRegressionModel", jobj = jobj) + }) + +# Get the summary of an IsotonicRegressionModel model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes model's \code{boundaries} (boundaries in increasing order) +#' and \code{predictions} (predictions associated with the boundaries at the same index). +#' @rdname spark.isoreg +#' @aliases summary,IsotonicRegressionModel-method +#' @export +#' @note summary(IsotonicRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "IsotonicRegressionModel"), + function(object) { + jobj <- object@jobj + boundaries <- callJMethod(jobj, "boundaries") + predictions <- callJMethod(jobj, "predictions") + list(boundaries = boundaries, predictions = predictions) + }) + +# Predicted values based on an isotonicRegression model + +#' @param object a fitted IsotonicRegressionModel. +#' @param newData SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted values. +#' @rdname spark.isoreg +#' @aliases predict,IsotonicRegressionModel,SparkDataFrame-method +#' @export +#' @note predict(IsotonicRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "IsotonicRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save fitted IsotonicRegressionModel to the input path + +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. 
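+#' For example, \code{write.ml(model, path, overwrite = TRUE)} replaces any model previously
+#' saved at \code{path} instead of raising an error.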
+#' +#' @rdname spark.isoreg +#' @aliases write.ml,IsotonicRegressionModel,character-method +#' @export +#' @note write.ml(IsotonicRegression, character) since 2.1.0 +setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Accelerated Failure Time (AFT) Survival Regression Model +#' +#' \code{spark.survreg} fits an accelerated failure time (AFT) survival regression model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted AFT model, +#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' Note that operator '.' is not supported currently. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. +#' @param ... additional arguments passed to the method. +#' @return \code{spark.survreg} returns a fitted AFT survival regression model. +#' @rdname spark.survreg +#' @seealso survival: \url{https://cran.r-project.org/package=survival} +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(ovarian) +#' model <- spark.survreg(df, Surv(futime, fustat) ~ ecog_ps + rx) +#' +#' # get a summary of the model +#' summary(model) +#' +#' # make predictions +#' predicted <- predict(model, df) +#' showDF(predicted) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' } +#' @note spark.survreg since 2.0.0 +setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, aggregationDepth = 2) { + formula <- paste(deparse(formula), collapse = "") + jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", + "fit", formula, data@sdf, as.integer(aggregationDepth)) + new("AFTSurvivalRegressionModel", jobj = jobj) + }) + +# Returns a summary of the AFT survival regression model produced by spark.survreg, +# similarly to R's summary(). + +#' @param object a fitted AFT survival regression model. +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list includes the model's \code{coefficients} (features, coefficients, +#' intercept and log(scale)). +#' @rdname spark.survreg +#' @export +#' @note summary(AFTSurvivalRegressionModel) since 2.0.0 +setMethod("summary", signature(object = "AFTSurvivalRegressionModel"), + function(object) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Value") + rownames(coefficients) <- unlist(features) + list(coefficients = coefficients) + }) + +# Makes predictions from an AFT survival regression model or a model produced by +# spark.survreg, similarly to R package survival's predict. + +#' @param newData a SparkDataFrame for testing. 
+#' @return \code{predict} returns a SparkDataFrame containing predicted values
+#' on the original scale of the data (mean predicted value at scale = 1.0).
+#' @rdname spark.survreg
+#' @export
+#' @note predict(AFTSurvivalRegressionModel) since 2.0.0
setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+# Saves the AFT survival regression model to the input path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#' @rdname spark.survreg
+#' @export
+#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
+#' @seealso \link{write.ml}
setMethod("write.ml", signature(object = "AFTSurvivalRegressionModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })
diff --git a/R/pkg/R/mllib_stat.R b/R/pkg/R/mllib_stat.R
new file mode 100644
index 0000000000000..3e013f1d45e38
--- /dev/null
+++ b/R/pkg/R/mllib_stat.R
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_stat.R: Provides methods for MLlib statistics algorithms integration
+
+#' S4 class that represents a KSTest
+#'
+#' @param jobj a Java object reference to the backing Scala KSTestWrapper
+#' @export
+#' @note KSTest since 2.1.0
+setClass("KSTest", representation(jobj = "jobj"))
+
+#' (One-Sample) Kolmogorov-Smirnov Test
+#'
+#' @description
+#' \code{spark.kstest} conducts the two-sided Kolmogorov-Smirnov (KS) test for data sampled from a
+#' continuous distribution.
+#'
+#' By comparing the largest difference between the empirical cumulative
+#' distribution of the sample data and the theoretical distribution, we can provide a test for
+#' the null hypothesis that the sample data comes from that theoretical distribution.
+#'
+#' Users can call \code{summary} to obtain a summary of the test, and \code{print.summary.KSTest}
+#' to print out a summary result.
+#'
+#' @param data a SparkDataFrame of user data.
+#' @param testCol column name where the test data is from. It should be a column of double type.
+#' @param nullHypothesis name of the theoretical distribution tested against. Currently only
+#' \code{"norm"} for normal distribution is supported.
+#' @param distParams parameter(s) of the distribution. For \code{nullHypothesis = "norm"},
+#' we can provide as a vector the mean and standard deviation of
+#' the distribution. If none is provided, then standard normal will be used.
+#' If only one is provided, then the standard deviation will be set to be one.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.kstest} returns a test result object. +#' @rdname spark.kstest +#' @aliases spark.kstest,SparkDataFrame-method +#' @name spark.kstest +#' @seealso \href{http://spark.apache.org/docs/latest/mllib-statistics.html#hypothesis-testing}{ +#' MLlib: Hypothesis Testing} +#' @export +#' @examples +#' \dontrun{ +#' data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25)) +#' df <- createDataFrame(data) +#' test <- spark.kstest(df, "test", "norm", c(0, 1)) +#' +#' # get a summary of the test result +#' testSummary <- summary(test) +#' testSummary +#' +#' # print out the summary in an organized way +#' print.summary.KSTest(testSummary) +#' } +#' @note spark.kstest since 2.1.0 +setMethod("spark.kstest", signature(data = "SparkDataFrame"), + function(data, testCol = "test", nullHypothesis = c("norm"), distParams = c(0, 1)) { + tryCatch(match.arg(nullHypothesis), + error = function(e) { + msg <- paste("Distribution", nullHypothesis, "is not supported.") + stop(msg) + }) + if (nullHypothesis == "norm") { + distParams <- as.numeric(distParams) + mu <- ifelse(length(distParams) < 1, 0, distParams[1]) + sigma <- ifelse(length(distParams) < 2, 1, distParams[2]) + jobj <- callJStatic("org.apache.spark.ml.r.KSTestWrapper", + "test", data@sdf, testCol, nullHypothesis, + as.array(c(mu, sigma))) + new("KSTest", jobj = jobj) + } +}) + +# Get the summary of Kolmogorov-Smirnov (KS) Test. + +#' @param object test result object of KSTest by \code{spark.kstest}. +#' @return \code{summary} returns summary information of KSTest object, which is a list. +#' The list includes the \code{p.value} (p-value), \code{statistic} (test statistic +#' computed for the test), \code{nullHypothesis} (the null hypothesis with its +#' parameters tested against) and \code{degreesOfFreedom} (degrees of freedom of the test). +#' @rdname spark.kstest +#' @aliases summary,KSTest-method +#' @export +#' @note summary(KSTest) since 2.1.0 +setMethod("summary", signature(object = "KSTest"), + function(object) { + jobj <- object@jobj + pValue <- callJMethod(jobj, "pValue") + statistic <- callJMethod(jobj, "statistic") + nullHypothesis <- callJMethod(jobj, "nullHypothesis") + distName <- callJMethod(jobj, "distName") + distParams <- unlist(callJMethod(jobj, "distParams")) + degreesOfFreedom <- callJMethod(jobj, "degreesOfFreedom") + + ans <- list(p.value = pValue, statistic = statistic, nullHypothesis = nullHypothesis, + nullHypothesis.name = distName, nullHypothesis.parameters = distParams, + degreesOfFreedom = degreesOfFreedom, jobj = jobj) + class(ans) <- "summary.KSTest" + ans + }) + +# Prints the summary of KSTest + +#' @rdname spark.kstest +#' @param x summary object of KSTest returned by \code{summary}. +#' @export +#' @note print.summary.KSTest since 2.1.0 +print.summary.KSTest <- function(x, ...) { + jobj <- x$jobj + summaryStr <- callJMethod(jobj, "summary") + cat(summaryStr, "\n") + invisible(x) +} diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R new file mode 100644 index 0000000000000..82279be6fbe77 --- /dev/null +++ b/R/pkg/R/mllib_tree.R @@ -0,0 +1,501 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib_tree.R: Provides methods for MLlib tree-based algorithms integration + +#' S4 class that represents a GBTRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala GBTRegressionModel +#' @export +#' @note GBTRegressionModel since 2.1.0 +setClass("GBTRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a GBTClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala GBTClassificationModel +#' @export +#' @note GBTClassificationModel since 2.1.0 +setClass("GBTClassificationModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestRegressionModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestRegressionModel +#' @export +#' @note RandomForestRegressionModel since 2.1.0 +setClass("RandomForestRegressionModel", representation(jobj = "jobj")) + +#' S4 class that represents a RandomForestClassificationModel +#' +#' @param jobj a Java object reference to the backing Scala RandomForestClassificationModel +#' @export +#' @note RandomForestClassificationModel since 2.1.0 +setClass("RandomForestClassificationModel", representation(jobj = "jobj")) + +# Create the summary of a tree ensemble model (eg. Random Forest, GBT) +summary.treeEnsemble <- function(model) { + jobj <- model@jobj + formula <- callJMethod(jobj, "formula") + numFeatures <- callJMethod(jobj, "numFeatures") + features <- callJMethod(jobj, "features") + featureImportances <- callJMethod(callJMethod(jobj, "featureImportances"), "toString") + maxDepth <- callJMethod(jobj, "maxDepth") + numTrees <- callJMethod(jobj, "numTrees") + treeWeights <- callJMethod(jobj, "treeWeights") + list(formula = formula, + numFeatures = numFeatures, + features = features, + featureImportances = featureImportances, + maxDepth = maxDepth, + numTrees = numTrees, + treeWeights = treeWeights, + jobj = jobj) +} + +# Prints the summary of tree ensemble models (eg. Random Forest, GBT) +print.summary.treeEnsemble <- function(x) { + jobj <- x$jobj + cat("Formula: ", x$formula) + cat("\nNumber of features: ", x$numFeatures) + cat("\nFeatures: ", unlist(x$features)) + cat("\nFeature importances: ", x$featureImportances) + cat("\nMax Depth: ", x$maxDepth) + cat("\nNumber of trees: ", x$numTrees) + cat("\nTree weights: ", unlist(x$treeWeights)) + + summaryStr <- callJMethod(jobj, "summary") + cat("\n", summaryStr, "\n") + invisible(x) +} + +#' Gradient Boosted Tree Model for Regression and Classification +#' +#' \code{spark.gbt} fits a Gradient Boosted Tree Regression model or Classification model on a +#' SparkDataFrame. Users can call \code{summary} to get a summary of the fitted +#' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and +#' \code{write.ml}/\code{read.ml} to save/load fitted models. 
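# A minimal sketch of the shared tree-ensemble summary produced by the helper above,
# assuming 'model' is any model fitted with spark.gbt or spark.randomForest in an
# active SparkR session:
ensSummary <- summary(model)
ensSummary$numTrees            # number of trees in the ensemble
ensSummary$maxDepth            # maximum tree depth used
ensSummary$featureImportances  # string form of the feature importance vector
unlist(ensSummary$treeWeights) # per-tree weights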
+#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ +#' GBT Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ +#' GBT Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. +#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param maxIter Param for maximum number of iterations (>= 0). +#' @param stepSize Param for Step size to be used for each iteration of optimization. +#' @param lossType Loss function which GBT tries to minimize. +#' For classification, must be "logistic". For regression, must be one of +#' "squared" (L2) and "absolute" (L1), default is "squared". +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. If a +#' split causes the left or right child to have fewer than +#' minInstancesPerNode, the split will be discarded as invalid. Should be +#' >= 1. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.gbt,SparkDataFrame,formula-method +#' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. +#' @rdname spark.gbt +#' @name spark.gbt +#' @export +#' @examples +#' \dontrun{ +#' # fit a Gradient Boosted Tree Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Gradient Boosted Tree Classification Model +#' # label must be binary - Only binary classification is supported for GBT. 
+#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.gbt(df, Survived ~ Age + Freq, "classification") +#' +#' # numeric label is also supported +#' t2 <- as.data.frame(Titanic) +#' t2$NumericGender <- ifelse(t2$Sex == "Male", 0, 1) +#' df <- createDataFrame(t2) +#' model <- spark.gbt(df, NumericGender ~ ., type = "classification") +#' } +#' @note spark.gbt since 2.1.0 +setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, maxIter = 20, stepSize = 0.1, lossType = NULL, + seed = NULL, subsamplingRate = 1.0, minInstancesPerNode = 1, minInfoGain = 0.0, + checkpointInterval = 10, maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(lossType)) lossType <- "squared" + lossType <- match.arg(lossType, c("squared", "absolute")) + jobj <- callJStatic("org.apache.spark.ml.r.GBTRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(lossType)) lossType <- "logistic" + lossType <- match.arg(lossType, "logistic") + jobj <- callJStatic("org.apache.spark.ml.r.GBTClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(maxIter), + as.numeric(stepSize), as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + lossType, seed, as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("GBTClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Gradient Boosted Tree Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). +#' @rdname spark.gbt +#' @aliases summary,GBTRegressionModel-method +#' @export +#' @note summary(GBTRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "GBTRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTRegressionModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Regression Model + +#' @param x summary object of Gradient Boosted Tree regression model or classification model +#' returned by \code{summary}. +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTRegressionModel since 2.1.0 +print.summary.GBTRegressionModel <- function(x, ...) 
{ + print.summary.treeEnsemble(x) +} + +# Get the summary of a Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @aliases summary,GBTClassificationModel-method +#' @export +#' @note summary(GBTClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "GBTClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.GBTClassificationModel" + ans + }) + +# Prints the summary of Gradient Boosted Tree Classification Model + +#' @rdname spark.gbt +#' @export +#' @note print.summary.GBTClassificationModel since 2.1.0 +print.summary.GBTClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Makes predictions from a Gradient Boosted Tree Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.gbt +#' @aliases predict,GBTRegressionModel-method +#' @export +#' @note predict(GBTRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "GBTRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.gbt +#' @aliases predict,GBTClassificationModel-method +#' @export +#' @note predict(GBTClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "GBTClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Gradient Boosted Tree Regression or Classification model to the input path. + +#' @param object A fitted Gradient Boosted Tree regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' @aliases write.ml,GBTRegressionModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,GBTClassificationModel,character-method +#' @rdname spark.gbt +#' @export +#' @note write.ml(GBTClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "GBTClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' Random Forest Model for Regression and Classification +#' +#' \code{spark.randomForest} fits a Random Forest Regression model or Classification model on +#' a SparkDataFrame. Users can call \code{summary} to get a summary of the fitted Random Forest +#' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to +#' save/load fitted models. +#' For more details, see +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ +#' Random Forest Regression} and +#' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ +#' Random Forest Classification} +#' +#' @param data a SparkDataFrame for training. +#' @param formula a symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~', ':', '+', and '-'. 
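# A minimal sketch tying together the GBT methods above (fit, predict, save, reload),
# assuming an active SparkR session; lossType is left NULL so it defaults to
# "squared" for regression, as documented:
longleyDF <- createDataFrame(longley)
gbtModel <- spark.gbt(longleyDF, Employed ~ ., type = "regression",
                      maxDepth = 3, maxIter = 10, seed = 123)
gbtPath <- file.path(tempdir(), "gbtModel")
write.ml(gbtModel, gbtPath)
gbtReloaded <- read.ml(gbtPath)
head(predict(gbtReloaded, longleyDF))  # predictions appear in column "prediction"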
+#' @param type type of model, one of "regression" or "classification", to fit +#' @param maxDepth Maximum depth of the tree (>= 0). +#' @param maxBins Maximum number of bins used for discretizing continuous features and for choosing +#' how to split on features at each node. More bins give higher granularity. Must be +#' >= 2 and >= number of categories in any categorical feature. +#' @param numTrees Number of trees to train (>= 1). +#' @param impurity Criterion used for information gain calculation. +#' For regression, must be "variance". For classification, must be one of +#' "entropy" and "gini", default is "gini". +#' @param featureSubsetStrategy The number of features to consider for splits at each tree node. +#' Supported options: "auto", "all", "onethird", "sqrt", "log2", (0.0-1.0], [1-n]. +#' @param seed integer seed for random number generation. +#' @param subsamplingRate Fraction of the training data used for learning each decision tree, in +#' range (0, 1]. +#' @param minInstancesPerNode Minimum number of instances each child must have after split. +#' @param minInfoGain Minimum information gain for a split to be considered at a tree node. +#' @param checkpointInterval Param for set checkpoint interval (>= 1) or disable checkpoint (-1). +#' @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. +#' @param cacheNodeIds If FALSE, the algorithm will pass trees to executors to match instances with +#' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching +#' can speed up training of deeper trees. Users can set how often should the +#' cache be checkpointed or disable it by setting checkpointInterval. +#' @param ... additional arguments passed to the method. +#' @aliases spark.randomForest,SparkDataFrame,formula-method +#' @return \code{spark.randomForest} returns a fitted Random Forest model. 
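# A minimal sketch of the featureSubsetStrategy values listed above, assuming an
# active SparkR session; the strategy is coerced to a character string by the
# wrapper, so a named strategy and a fraction are written the same way:
rfDF <- createDataFrame(longley)
rfSqrt <- spark.randomForest(rfDF, Employed ~ ., "regression",
                             numTrees = 10, featureSubsetStrategy = "sqrt")
rfHalf <- spark.randomForest(rfDF, Employed ~ ., "regression",
                             numTrees = 10, featureSubsetStrategy = "0.5")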
+#' @rdname spark.randomForest +#' @name spark.randomForest +#' @export +#' @examples +#' \dontrun{ +#' # fit a Random Forest Regression Model +#' df <- createDataFrame(longley) +#' model <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 5, maxBins = 16) +#' +#' # get the summary of the model +#' summary(model) +#' +#' # make predictions +#' predictions <- predict(model, df) +#' +#' # save and load the model +#' path <- "path/to/model" +#' write.ml(model, path) +#' savedModel <- read.ml(path) +#' summary(savedModel) +#' +#' # fit a Random Forest Classification Model +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.randomForest(df, Survived ~ Freq + Age, "classification") +#' } +#' @note spark.randomForest since 2.1.0 +setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), + function(data, formula, type = c("regression", "classification"), + maxDepth = 5, maxBins = 32, numTrees = 20, impurity = NULL, + featureSubsetStrategy = "auto", seed = NULL, subsamplingRate = 1.0, + minInstancesPerNode = 1, minInfoGain = 0.0, checkpointInterval = 10, + maxMemoryInMB = 256, cacheNodeIds = FALSE) { + type <- match.arg(type) + formula <- paste(deparse(formula), collapse = "") + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } + switch(type, + regression = { + if (is.null(impurity)) impurity <- "variance" + impurity <- match.arg(impurity, "variance") + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestRegressorWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestRegressionModel", jobj = jobj) + }, + classification = { + if (is.null(impurity)) impurity <- "gini" + impurity <- match.arg(impurity, c("gini", "entropy")) + jobj <- callJStatic("org.apache.spark.ml.r.RandomForestClassifierWrapper", + "fit", data@sdf, formula, as.integer(maxDepth), + as.integer(maxBins), as.integer(numTrees), + impurity, as.integer(minInstancesPerNode), + as.numeric(minInfoGain), as.integer(checkpointInterval), + as.character(featureSubsetStrategy), seed, + as.numeric(subsamplingRate), + as.integer(maxMemoryInMB), as.logical(cacheNodeIds)) + new("RandomForestClassificationModel", jobj = jobj) + } + ) + }) + +# Get the summary of a Random Forest Regression Model + +#' @return \code{summary} returns summary information of the fitted model, which is a list. +#' The list of components includes \code{formula} (formula), +#' \code{numFeatures} (number of features), \code{features} (list of features), +#' \code{featureImportances} (feature importances), \code{maxDepth} (max depth of trees), +#' \code{numTrees} (number of trees), and \code{treeWeights} (tree weights). +#' @rdname spark.randomForest +#' @aliases summary,RandomForestRegressionModel-method +#' @export +#' @note summary(RandomForestRegressionModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestRegressionModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestRegressionModel" + ans + }) + +# Prints the summary of Random Forest Regression Model + +#' @param x summary object of Random Forest regression model or classification model +#' returned by \code{summary}. 
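# A minimal sketch of a Random Forest classifier, mirroring the example above and
# assuming an active SparkR session; impurity is left NULL so it defaults to "gini"
# for classification:
titanicDF <- createDataFrame(as.data.frame(Titanic))
rfClassifier <- spark.randomForest(titanicDF, Survived ~ Freq + Age,
                                   "classification", numTrees = 10, seed = 42)
rfSummary <- summary(rfClassifier)
rfSummary$numTrees  # 10
head(predict(rfClassifier, titanicDF))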
+#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestRegressionModel since 2.1.0 +print.summary.RandomForestRegressionModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Get the summary of a Random Forest Classification Model + +#' @rdname spark.randomForest +#' @aliases summary,RandomForestClassificationModel-method +#' @export +#' @note summary(RandomForestClassificationModel) since 2.1.0 +setMethod("summary", signature(object = "RandomForestClassificationModel"), + function(object) { + ans <- summary.treeEnsemble(object) + class(ans) <- "summary.RandomForestClassificationModel" + ans + }) + +# Prints the summary of Random Forest Classification Model + +#' @rdname spark.randomForest +#' @export +#' @note print.summary.RandomForestClassificationModel since 2.1.0 +print.summary.RandomForestClassificationModel <- function(x, ...) { + print.summary.treeEnsemble(x) +} + +# Makes predictions from a Random Forest Regression model or Classification model + +#' @param newData a SparkDataFrame for testing. +#' @return \code{predict} returns a SparkDataFrame containing predicted labeled in a column named +#' "prediction". +#' @rdname spark.randomForest +#' @aliases predict,RandomForestRegressionModel-method +#' @export +#' @note predict(RandomForestRegressionModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestRegressionModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +#' @rdname spark.randomForest +#' @aliases predict,RandomForestClassificationModel-method +#' @export +#' @note predict(RandomForestClassificationModel) since 2.1.0 +setMethod("predict", signature(object = "RandomForestClassificationModel"), + function(object, newData) { + predict_internal(object, newData) + }) + +# Save the Random Forest Regression or Classification model to the input path. + +#' @param object A fitted Random Forest regression model or classification model. +#' @param path The directory where the model is saved. +#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE +#' which means throw exception if the output path exists. +#' +#' @aliases write.ml,RandomForestRegressionModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestRegressionModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestRegressionModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) + +#' @aliases write.ml,RandomForestClassificationModel,character-method +#' @rdname spark.randomForest +#' @export +#' @note write.ml(RandomForestClassificationModel, character) since 2.1.0 +setMethod("write.ml", signature(object = "RandomForestClassificationModel", path = "character"), + function(object, path, overwrite = FALSE) { + write_internal(object, path, overwrite) + }) diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R new file mode 100644 index 0000000000000..5dfef8625061b --- /dev/null +++ b/R/pkg/R/mllib_utils.R @@ -0,0 +1,126 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_utils.R: Utilities for MLlib integration
+
+# Integration with R's standard functions.
+# Most of MLlib's algorithms are provided in two flavours:
+# - a specialization of the default R methods (glm). These methods try to respect
+# the inputs and the outputs of R's method to the largest extent, but some small differences
+# may exist.
+# - a set of methods that reflect the arguments of the other languages supported by Spark. These
+# methods are prefixed with `spark.`: spark.glm, spark.kmeans, etc.
+
+#' Saves the MLlib model to the input path
+#'
+#' Saves the MLlib model to the input path. For more information, see the specific
+#' MLlib model below.
+#' @rdname write.ml
+#' @name write.ml
+#' @export
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
+#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
+#' @seealso \link{spark.kmeans},
+#' @seealso \link{spark.lda}, \link{spark.logit},
+#' @seealso \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear},
+#' @seealso \link{read.ml}
+NULL
+
+#' Makes predictions from an MLlib model
+#'
+#' Makes predictions from an MLlib model. For more information, see the specific
+#' MLlib model below.
+#' @rdname predict
+#' @name predict
+#' @export
+#' @seealso \link{spark.als}, \link{spark.bisectingKmeans}, \link{spark.gaussianMixture},
+#' @seealso \link{spark.gbt}, \link{spark.glm}, \link{glm}, \link{spark.isoreg},
+#' @seealso \link{spark.kmeans},
+#' @seealso \link{spark.logit}, \link{spark.mlp}, \link{spark.naiveBayes},
+#' @seealso \link{spark.randomForest}, \link{spark.survreg}, \link{spark.svmLinear}
+NULL
+
+write_internal <- function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+}
+
+predict_internal <- function(object, newData) {
+ dataFrame(callJMethod(object@jobj, "transform", newData@sdf))
+}
+
+#' Load a fitted MLlib model from the input path.
+#'
+#' @param path path of the model to read.
+#' @return A fitted MLlib model.
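# A minimal sketch of the save/load helpers documented above, assuming an active
# SparkR session and any fitted MLlib model 'model' produced by these wrappers;
# saving to an existing path only succeeds when overwrite = TRUE:
modelDir <- file.path(tempdir(), "sparkr-saved-model")
write.ml(model, modelDir)
# write.ml(model, modelDir)            # would fail: the path already exists
write.ml(model, modelDir, overwrite = TRUE)
reloaded <- read.ml(modelDir)          # returns the matching S4 model class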
+#' @rdname read.ml +#' @name read.ml +#' @export +#' @seealso \link{write.ml} +#' @examples +#' \dontrun{ +#' path <- "path/to/model" +#' model <- read.ml(path) +#' } +#' @note read.ml since 2.0.0 +read.ml <- function(path) { + path <- suppressWarnings(normalizePath(path)) + sparkSession <- getSparkSession() + callJStatic("org.apache.spark.ml.r.RWrappers", "session", sparkSession) + jobj <- callJStatic("org.apache.spark.ml.r.RWrappers", "load", path) + if (isInstanceOf(jobj, "org.apache.spark.ml.r.NaiveBayesWrapper")) { + new("NaiveBayesModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.AFTSurvivalRegressionWrapper")) { + new("AFTSurvivalRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper")) { + new("GeneralizedLinearRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) { + new("KMeansModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) { + new("LDAModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper")) { + new("MultilayerPerceptronClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) { + new("IsotonicRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) { + new("GaussianMixtureModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.ALSWrapper")) { + new("ALSModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LogisticRegressionWrapper")) { + new("LogisticRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestRegressorWrapper")) { + new("RandomForestRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.RandomForestClassifierWrapper")) { + new("RandomForestClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTRegressorWrapper")) { + new("GBTRegressionModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GBTClassifierWrapper")) { + new("GBTClassificationModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.BisectingKMeansWrapper")) { + new("BisectingKMeansModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) { + new("LinearSVCModel", jobj = jobj) + } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) { + new("FPGrowthModel", jobj = jobj) + } else { + stop("Unsupported model: ", jobj) + } +} diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 4dee3245f9b75..8fa21be3076b5 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -780,7 +780,7 @@ setMethod("cogroup", #' @noRd setMethod("sortByKey", signature(x = "RDD"), - function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, ascending = TRUE, numPartitions = SparkR:::getNumPartitionsRDD(x)) { rangeBounds <- list() if (numPartitions > 1) { @@ -850,7 +850,7 @@ setMethod("sortByKey", #' @noRd setMethod("subtractByKey", signature(x = "RDD", other = "RDD"), - function(x, other, numPartitions = SparkR:::getNumPartitions(x)) { + function(x, other, numPartitions = SparkR:::getNumPartitionsRDD(x)) { filterFunction <- function(elem) { iters <- elem[[2]] (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0) diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 6b4a2f2fdc85c..d0a12b7ecec65 100644 
--- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -323,6 +323,18 @@ sparkRHive.init <- function(jsc = NULL) { #' Additional Spark properties can be set in \code{...}, and these named parameters take priority #' over values in \code{master}, \code{appName}, named lists of \code{sparkConfig}. #' +#' When called in an interactive session, this method checks for the Spark installation, and, if not +#' found, it will be downloaded and cached automatically. Alternatively, \code{install.spark} can +#' be called manually. +#' +#' A default warehouse is created automatically in the current directory when a managed table is +#' created via \code{sql} statement \code{CREATE TABLE}, for example. To change the location of the +#' warehouse, set the named parameter \code{spark.sql.warehouse.dir} to the SparkSession. Along with +#' the warehouse, an accompanied metastore may also be automatically created in the current +#' directory when a new SparkSession is initialized with \code{enableHiveSupport} set to +#' \code{TRUE}, which is the default. For more details, refer to Hive configuration at +#' \url{http://spark.apache.org/docs/latest/sql-programming-guide.html#hive-tables}. +#' #' For details on how to initialize and use SparkR, refer to SparkR programming guide at #' \url{http://spark.apache.org/docs/latest/sparkr.html#starting-up-sparksession}. #' @@ -373,8 +385,17 @@ sparkR.session <- function( overrideEnvs(sparkConfigMap, paramMap) } + deployMode <- "" + if (exists("spark.submit.deployMode", envir = sparkConfigMap)) { + deployMode <- sparkConfigMap[["spark.submit.deployMode"]] + } + + if (!exists("spark.r.sql.derby.temp.dir", envir = sparkConfigMap)) { + sparkConfigMap[["spark.r.sql.derby.temp.dir"]] <- tempdir() + } + if (!exists(".sparkRjsc", envir = .sparkREnv)) { - retHome <- sparkCheckInstall(sparkHome, master) + retHome <- sparkCheckInstall(sparkHome, master, deployMode) if (!is.null(retHome)) sparkHome <- retHome sparkExecutorEnvMap <- new.env() sparkR.sparkContext(master, appName, sparkHome, sparkConfigMap, sparkExecutorEnvMap, @@ -402,6 +423,30 @@ sparkR.session <- function( sparkSession } +#' Get the URL of the SparkUI instance for the current active SparkSession +#' +#' Get the URL of the SparkUI instance for the current active SparkSession. +#' +#' @return the SparkUI URL, or NA if it is disabled, or not started. +#' @rdname sparkR.uiWebUrl +#' @name sparkR.uiWebUrl +#' @export +#' @examples +#'\dontrun{ +#' sparkR.session() +#' url <- sparkR.uiWebUrl() +#' } +#' @note sparkR.uiWebUrl since 2.1.1 +sparkR.uiWebUrl <- function() { + sc <- sparkR.callJMethod(getSparkContext(), "sc") + u <- callJMethod(sc, "uiWebUrl") + if (callJMethod(u, "isDefined")) { + callJMethod(u, "get") + } else { + NA + } +} + #' Assigns a group ID to all the jobs started by this thread until the group ID is set to a #' different value or cleared. 
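# A minimal sketch of the session options described above, assuming a local Spark
# installation is available (in interactive sessions it is downloaded automatically,
# as noted): extra named parameters to sparkR.session() become Spark properties, and
# sparkR.uiWebUrl() looks up the UI address of the active session (NA if not started).
sparkR.session(appName = "SparkR-example",
               spark.sql.warehouse.dir = file.path(tempdir(), "warehouse"))
sparkR.uiWebUrl()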
#' @@ -419,7 +464,7 @@ sparkR.session <- function( #' @method setJobGroup default setJobGroup.default <- function(groupId, description, interruptOnCancel) { sc <- getSparkContext() - callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel) + invisible(callJMethod(sc, "setJobGroup", groupId, description, interruptOnCancel)) } setJobGroup <- function(sc, groupId, description, interruptOnCancel) { @@ -449,7 +494,7 @@ setJobGroup <- function(sc, groupId, description, interruptOnCancel) { #' @method clearJobGroup default clearJobGroup.default <- function() { sc <- getSparkContext() - callJMethod(sc, "clearJobGroup") + invisible(callJMethod(sc, "clearJobGroup")) } clearJobGroup <- function(sc) { @@ -476,7 +521,7 @@ clearJobGroup <- function(sc) { #' @method cancelJobGroup default cancelJobGroup.default <- function(groupId) { sc <- getSparkContext() - callJMethod(sc, "cancelJobGroup", groupId) + invisible(callJMethod(sc, "cancelJobGroup", groupId)) } cancelJobGroup <- function(sc, groupId) { @@ -550,24 +595,25 @@ processSparkPackages <- function(packages) { # # @param sparkHome directory to find Spark package. # @param master the Spark master URL, used to check local or remote mode. +# @param deployMode whether to deploy your driver on the worker nodes (cluster) +# or locally as an external client (client). # @return NULL if no need to update sparkHome, and new sparkHome otherwise. -sparkCheckInstall <- function(sparkHome, master) { +sparkCheckInstall <- function(sparkHome, master, deployMode) { if (!isSparkRShell()) { if (!is.na(file.info(sparkHome)$isdir)) { - msg <- paste0("Spark package found in SPARK_HOME: ", sparkHome) - message(msg) + message("Spark package found in SPARK_HOME: ", sparkHome) NULL } else { - if (!nzchar(master) || isMasterLocal(master)) { - msg <- paste0("Spark not found in SPARK_HOME: ", - sparkHome) - message(msg) + if (interactive() || isMasterLocal(master)) { + message("Spark not found in SPARK_HOME: ", sparkHome) packageLocalDir <- install.spark() packageLocalDir - } else { + } else if (isClientMode(master) || deployMode == "client") { msg <- paste0("Spark not found in SPARK_HOME: ", sparkHome, "\n", installInstruction("remote")) stop(msg) + } else { + NULL } } } else { diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index dcd7198f41ea7..d78a10893f92e 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -138,9 +138,9 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), collect(dataFrame(sct)) }) -#' Calculates the approximate quantiles of a numerical column of a SparkDataFrame +#' Calculates the approximate quantiles of numerical columns of a SparkDataFrame #' -#' Calculates the approximate quantiles of a numerical column of a SparkDataFrame. +#' Calculates the approximate quantiles of numerical columns of a SparkDataFrame. #' The result of this algorithm has the following deterministic bound: #' If the SparkDataFrame has N elements and if we request the quantile at probability p up to #' error err, then the algorithm will return a sample x from the SparkDataFrame so that the @@ -149,15 +149,20 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' This method implements a variation of the Greenwald-Khanna algorithm (with some speed #' optimizations). The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 #' Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. +#' Note that NA values will be ignored in numerical columns before calculation. 
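# A minimal sketch of approxQuantile after this change, assuming an active SparkR
# session: a single column name yields one list of quantiles, several names yield
# one such list per column, and NA values are ignored as noted above.
quantDF <- createDataFrame(data.frame(a = c(1, 2, 3, 4, NA), b = c(10, 20, 30, 40, 50)))
approxQuantile(quantDF, "a", c(0.5), 0.0)                 # exact median of column "a"
approxQuantile(quantDF, c("a", "b"), c(0.25, 0.75), 0.0)  # a list of two lists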
For +#' columns only containing NA values, an empty list is returned. #' #' @param x A SparkDataFrame. -#' @param col The name of the numerical column. +#' @param cols A single column name, or a list of names for multiple columns. #' @param probabilities A list of quantile probabilities. Each number must belong to [0, 1]. #' For example 0 is the minimum, 0.5 is the median, 1 is the maximum. #' @param relativeError The relative target precision to achieve (>= 0). If set to zero, #' the exact quantiles are computed, which could be very expensive. #' Note that values greater than 1 are accepted but give the same result as 1. -#' @return The approximate quantiles at the given probabilities. +#' @return The approximate quantiles at the given probabilities. If the input is a single column name, +#' the output is a list of approximate quantiles in that column; If the input is +#' multiple column names, the output should be a list, and each element in it is a list of +#' numeric values which represents the approximate quantiles in corresponding column. #' #' @rdname approxQuantile #' @name approxQuantile @@ -171,12 +176,17 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' } #' @note approxQuantile since 2.0.0 setMethod("approxQuantile", - signature(x = "SparkDataFrame", col = "character", + signature(x = "SparkDataFrame", cols = "character", probabilities = "numeric", relativeError = "numeric"), - function(x, col, probabilities, relativeError) { + function(x, cols, probabilities, relativeError) { statFunctions <- callJMethod(x@sdf, "stat") - callJMethod(statFunctions, "approxQuantile", col, - as.list(probabilities), relativeError) + quantiles <- callJMethod(statFunctions, "approxQuantile", as.list(cols), + as.list(probabilities), relativeError) + if (length(cols) == 1) { + quantiles[[1]] + } else { + quantiles + } }) #' Returns a stratified sample without replacement diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R new file mode 100644 index 0000000000000..8390bd5e6de72 --- /dev/null +++ b/R/pkg/R/streaming.R @@ -0,0 +1,214 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+# streaming.R - Structured Streaming / StreamingQuery class and methods implemented in S4 OO classes
+
+#' @include generics.R jobj.R
+NULL
+
+#' S4 class that represents a StreamingQuery
+#'
+#' StreamingQuery can be created by using read.stream() and write.stream()
+#'
+#' @rdname StreamingQuery
+#' @seealso \link{read.stream}
+#'
+#' @param ssq A Java object reference to the backing Scala StreamingQuery
+#' @export
+#' @note StreamingQuery since 2.2.0
+#' @note experimental
+setClass("StreamingQuery",
+ slots = list(ssq = "jobj"))
+
+setMethod("initialize", "StreamingQuery", function(.Object, ssq) {
+ .Object@ssq <- ssq
+ .Object
+})
+
+streamingQuery <- function(ssq) {
+ stopifnot(class(ssq) == "jobj")
+ new("StreamingQuery", ssq)
+}
+
+#' @rdname show
+#' @export
+#' @note show(StreamingQuery) since 2.2.0
+setMethod("show", "StreamingQuery",
+ function(object) {
+ name <- callJMethod(object@ssq, "name")
+ if (!is.null(name)) {
+ cat(paste0("StreamingQuery '", name, "'\n"))
+ } else {
+ cat("StreamingQuery", "\n")
+ }
+ })
+
+#' queryName
+#'
+#' Returns the user-specified name of the query. This is specified in
+#' \code{write.stream(df, queryName = "query")}. This name, if set, must be unique across all active
+#' queries.
+#'
+#' @param x a StreamingQuery.
+#' @return The name of the query, or NULL if not specified.
+#' @rdname queryName
+#' @name queryName
+#' @aliases queryName,StreamingQuery-method
+#' @family StreamingQuery methods
+#' @seealso \link{write.stream}
+#' @export
+#' @examples
+#' \dontrun{ queryName(sq) }
+#' @note queryName(StreamingQuery) since 2.2.0
+#' @note experimental
+setMethod("queryName",
+ signature(x = "StreamingQuery"),
+ function(x) {
+ callJMethod(x@ssq, "name")
+ })
+
+#' @rdname explain
+#' @name explain
+#' @aliases explain,StreamingQuery-method
+#' @family StreamingQuery methods
+#' @export
+#' @examples
+#' \dontrun{ explain(sq) }
+#' @note explain(StreamingQuery) since 2.2.0
+setMethod("explain",
+ signature(x = "StreamingQuery"),
+ function(x, extended = FALSE) {
+ cat(callJMethod(x@ssq, "explainInternal", extended), "\n")
+ })
+
+#' lastProgress
+#'
+#' Prints the most recent progress update of this streaming query in JSON format.
+#'
+#' @param x a StreamingQuery.
+#' @rdname lastProgress
+#' @name lastProgress
+#' @aliases lastProgress,StreamingQuery-method
+#' @family StreamingQuery methods
+#' @export
+#' @examples
+#' \dontrun{ lastProgress(sq) }
+#' @note lastProgress(StreamingQuery) since 2.2.0
+#' @note experimental
+setMethod("lastProgress",
+ signature(x = "StreamingQuery"),
+ function(x) {
+ p <- callJMethod(x@ssq, "lastProgress")
+ if (is.null(p)) {
+ cat("Streaming query has no progress")
+ } else {
+ cat(callJMethod(p, "toString"), "\n")
+ }
+ })
+
+#' status
+#'
+#' Prints the current status of the query in JSON format.
+#'
+#' @param x a StreamingQuery.
+#' @rdname status
+#' @name status
+#' @aliases status,StreamingQuery-method
+#' @family StreamingQuery methods
+#' @export
+#' @examples
+#' \dontrun{ status(sq) }
+#' @note status(StreamingQuery) since 2.2.0
+#' @note experimental
+setMethod("status",
+ signature(x = "StreamingQuery"),
+ function(x) {
+ cat(callJMethod(callJMethod(x@ssq, "status"), "toString"), "\n")
+ })
+
+#' isActive
+#'
+#' Returns TRUE if this query is actively running.
+#'
+#' @param x a StreamingQuery.
+#' @return TRUE if query is actively running, FALSE if stopped.
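# A minimal sketch of the StreamingQuery helpers above, assuming 'sq' is a query
# handle returned by write.stream() on a streaming SparkDataFrame in an active
# SparkR session:
queryName(sq)      # user-specified name, or NULL if none was given
isActive(sq)       # TRUE while the query keeps running
status(sq)         # current status, printed as JSON
lastProgress(sq)   # most recent progress update, printed as JSON
explain(sq)        # prints the query plan
# stopQuery(sq) and awaitTermination(sq, timeout) are defined just below.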
+#' @rdname isActive +#' @name isActive +#' @aliases isActive,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ isActive(sq) } +#' @note isActive(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("isActive", + signature(x = "StreamingQuery"), + function(x) { + callJMethod(x@ssq, "isActive") + }) + +#' awaitTermination +#' +#' Waits for the termination of the query, either by \code{stopQuery} or by an error. +#' +#' If the query has terminated, then all subsequent calls to this method will return TRUE +#' immediately. +#' +#' @param x a StreamingQuery. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. +#' @rdname awaitTermination +#' @name awaitTermination +#' @aliases awaitTermination,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ awaitTermination(sq, 10000) } +#' @note awaitTermination(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("awaitTermination", + signature(x = "StreamingQuery"), + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } + }) + +#' stopQuery +#' +#' Stops the execution of this query if it is running. This method blocks until the execution is +#' stopped. +#' +#' @param x a StreamingQuery. +#' @rdname stopQuery +#' @name stopQuery +#' @aliases stopQuery,StreamingQuery-method +#' @family StreamingQuery methods +#' @export +#' @examples +#' \dontrun{ stopQuery(sq) } +#' @note stopQuery(StreamingQuery) since 2.2.0 +#' @note experimental +setMethod("stopQuery", + signature(x = "StreamingQuery"), + function(x) { + invisible(callJMethod(x@ssq, "stop")) + }) diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index abca703617c7b..ade0f05c02542 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -29,7 +29,7 @@ PRIMITIVE_TYPES <- as.environment(list( "string" = "character", "binary" = "raw", "boolean" = "logical", - "timestamp" = "POSIXct", + "timestamp" = c("POSIXct", "POSIXt"), "date" = "Date", # following types are not SQL types returned by dtypes(). They are listed here for usage # by checkType() in schema.R. diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 20004549cc037..d29af00affb98 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -756,12 +756,17 @@ varargsToJProperties <- function(...) { props } -launchScript <- function(script, combinedArgs, capture = FALSE) { +launchScript <- function(script, combinedArgs, wait = FALSE) { if (.Platform$OS.type == "windows") { scriptWithArgs <- paste(script, combinedArgs, sep = " ") - shell(scriptWithArgs, translate = TRUE, wait = capture, intern = capture) # nolint + # on Windows, intern = F seems to mean output to the console. (documentation on this is missing) + shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait) # nolint } else { - system2(script, combinedArgs, wait = capture, stdout = capture) + # http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html + # stdout = F means discard output + # stdout = "" means to its console (default) + # Note that the console of this child process might not be the same as the running R process. 
+ system2(script, combinedArgs, stdout = "", wait = wait) } } @@ -777,6 +782,10 @@ isMasterLocal <- function(master) { grepl("^local(\\[([0-9]+|\\*)\\])?$", master, perl = TRUE) } +isClientMode <- function(master) { + grepl("([a-z]+)-client$", master, perl = TRUE) +} + isSparkRShell <- function() { grepl(".*shell\\.R$", Sys.getenv("R_PROFILE_USER"), perl = TRUE) } @@ -814,7 +823,16 @@ captureJVMException <- function(e, method) { stacktrace <- rawmsg } - if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { + # StreamingQueryException could wrap an IllegalArgumentException, so look for that first + if (any(grep("org.apache.spark.sql.streaming.StreamingQueryException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.streaming.StreamingQueryException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "streaming query error - ", first), call. = FALSE) + } else if (any(grep("java.lang.IllegalArgumentException: ", stacktrace))) { msg <- strsplit(stacktrace, "java.lang.IllegalArgumentException: ", fixed = TRUE)[[1]] # Extract "Error in ..." message. rmsg <- msg[1] @@ -828,6 +846,32 @@ captureJVMException <- function(e, method) { # Extract the first message of JVM exception. first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] stop(paste0(rmsg, "analysis error - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such database - ", first), call. = FALSE) + } else + if (any(grep("org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.analysis.NoSuchTableException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "no such table - ", first), call. = FALSE) + } else if (any(grep("org.apache.spark.sql.catalyst.parser.ParseException: ", stacktrace))) { + msg <- strsplit(stacktrace, "org.apache.spark.sql.catalyst.parser.ParseException: ", + fixed = TRUE)[[1]] + # Extract "Error in ..." message. + rmsg <- msg[1] + # Extract the first message of JVM exception. + first <- strsplit(msg[2], "\r?\n\tat")[[1]][1] + stop(paste0(rmsg, "parse error - ", first), call. = FALSE) } else { stop(stacktrace, call. 
= FALSE) } @@ -837,7 +881,7 @@ captureJVMException <- function(e, method) { # # @param inputData a list of rows, with each row a list # @return data.frame with raw columns as lists -rbindRaws <- function(inputData){ +rbindRaws <- function(inputData) { row1 <- inputData[[1]] rawcolumns <- ("raw" == sapply(row1, class)) @@ -847,3 +891,19 @@ rbindRaws <- function(inputData){ out[!rawcolumns] <- lapply(out[!rawcolumns], unlist) out } + +# Get basename without extension from URL +basenameSansExtFromUrl <- function(url) { + # split by '/' + splits <- unlist(strsplit(url, "^.+/")) + last <- tail(splits, 1) + # this is from file_path_sans_ext + # first, remove any compression extension + filename <- sub("[.](gz|bz2|xz)$", "", last) + # then, strip extension by the last '.' + sub("([^.]+)\\.[[:alnum:]]+$", "\\1", filename) +} + +isAtomicLengthOne <- function(x) { + is.atomic(x) && length(x) == 1 +} diff --git a/R/pkg/inst/tests/testthat/jarTest.R b/R/pkg/inst/tests/testthat/jarTest.R index c9615c8d4faf6..e2241e03b55f8 100644 --- a/R/pkg/inst/tests/testthat/jarTest.R +++ b/R/pkg/inst/tests/testthat/jarTest.R @@ -16,7 +16,7 @@ # library(SparkR) -sc <- sparkR.session() +sc <- sparkR.session(master = "local[1]") helloTest <- SparkR:::callJStatic("sparkrtest.DummyClass", "helloWorld", diff --git a/R/pkg/inst/tests/testthat/packageInAJarTest.R b/R/pkg/inst/tests/testthat/packageInAJarTest.R index 4bc935c79eb0f..ac706261999fb 100644 --- a/R/pkg/inst/tests/testthat/packageInAJarTest.R +++ b/R/pkg/inst/tests/testthat/packageInAJarTest.R @@ -17,7 +17,7 @@ library(SparkR) library(sparkPackageTest) -sparkR.session() +sparkR.session(master = "local[1]") run1 <- myfunc(5L) diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa85..6e160fae1afed 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -17,9 +17,11 @@ context("SerDe functionality") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 8813e18a1fa4d..919b063bf0693 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,10 +17,13 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } - testOutput <- launchScript("ECHO", "a/b/c", capture = TRUE) + + testOutput <- launchScript("ECHO", "a/b/c", wait = TRUE) abcPath <- testOutput[1] expect_equal(abcPath, "a\\b\\c") }) diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e5..00954fa31b0ee 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ 
b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -18,12 +18,14 @@ context("functions on binary files") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e6204405..236cb3885445e 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -18,7 +18,7 @@ context("binary functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f43..2c96740df77bb 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -18,7 +18,7 @@ context("broadcast variables") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -26,8 
+26,10 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - randomMatBr <- broadcast(sc, randomMat) + randomMatBr <- broadcastRDD(sc, randomMat) useBroadcast <- function(x) { sum(SparkR:::value(randomMatBr) * x) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf39..3d53bebab6300 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index caca06933952b..f6d9f5423df02 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -18,13 +18,15 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. 
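  # "not" is being added to both lists below, presumably because SparkR now
  # exports a not() column function that completely masks the identically named
  # symbol from one of the checked packages. A minimal sketch of what this check
  # guards against, assuming SparkR is attached (illustrative only):
  #   maskedNow <- intersect(ls("package:SparkR"), ls("package:base"))
  #   setdiff(maskedNow, namesOfMasked)  # anything left here is an unexpected mask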
- namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMaskedCompletely <- c("cov", "filter", "sample", "not") namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame", "union") + "summary", "transform", "drop", "window", "as.data.frame", "union", "not") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } @@ -55,8 +57,10 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { - sc <- suppressWarnings(sparkR.init()) + sc <- suppressWarnings(sparkR.init(master = sparkRTestMaster)) rdd <- parallelize(sc, 1:20, 2L) expect_equal(countRDD(rdd), 20) suppressWarnings(sparkR.stop()) @@ -65,7 +69,7 @@ test_that("repeatedly starting and stopping SparkR", { test_that("repeatedly starting and stopping SparkSession", { for (i in 1:4) { - sparkR.session(enableHiveSupport = FALSE) + sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) df <- createDataFrame(data.frame(dummy = 1:i)) expect_equal(count(df), i) sparkR.session.stop() @@ -73,12 +77,14 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { - sc <- sparkR.sparkContext() # sc should get id 0 + skip_on_cran() + + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 sparkR.session.stop() - sc <- sparkR.sparkContext() # sc should get id 0 again + sc <- sparkR.sparkContext(master = sparkRTestMaster) # sc should get id 0 again # GC rdd1 before creating rdd3 and rdd2 after rm(rdd1) @@ -96,7 +102,9 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { - sc <- sparkR.sparkContext() + skip_on_cran() + + sc <- sparkR.sparkContext(master = sparkRTestMaster) setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") clearJobGroup() @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { - sparkR.sparkContext() + skip_on_cran() + + sparkR.sparkContext(master = sparkRTestMaster) setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -161,14 +175,16 @@ test_that("sparkJars sparkPackages as comma-separated strings", { }) test_that("spark.lapply should perform simple transforms", { - sparkR.sparkContext() + sparkR.sparkContext(master = sparkRTestMaster) doubled <- spark.lapply(1:10, function(x) { 2 * x }) expect_equal(doubled, as.list(2 * 1:10)) sparkR.session.stop() }) test_that("add and get file to be downloaded with Spark job on every node", { - sparkR.sparkContext() + skip_on_cran() + + sparkR.sparkContext(master = sparkRTestMaster) # Test add file. 
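  # spark.addFile() ships a local file to every node and spark.getSparkFiles()
  # resolves its local path; the lines below exercise both on the driver and,
  # via spark.lapply(), on executors. sparkRTestMaster is assumed to be defined
  # by the shared test setup (typically "local[1]" for these suites).
  #   e.g. spark.addFile("/tmp/data.csv"); spark.getSparkFiles("data.csv")  # illustrative paths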
path <- tempfile(pattern = "hello", fileext = ".txt") filename <- basename(path) @@ -177,6 +193,13 @@ test_that("add and get file to be downloaded with Spark job on every node", { spark.addFile(path) download_path <- spark.getSparkFiles(filename) expect_equal(readLines(download_path), words) + + # Test spark.getSparkFiles works well on executors. + seq <- seq(from = 1, to = 10, length.out = 5) + f <- function(seq) { spark.getSparkFiles(filename) } + results <- spark.lapply(seq, f) + for (i in 1:5) { expect_equal(basename(results[[i]]), filename) } + unlink(path) # Test add directory recursively. diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd8..d7d9eeed1575e 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -18,7 +18,7 @@ context("include R packages") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Partitioned data @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_jvm_api.R b/R/pkg/inst/tests/testthat/test_jvm_api.R index 7348c893d0af3..8b3b4f73de170 100644 --- a/R/pkg/inst/tests/testthat/test_jvm_api.R +++ b/R/pkg/inst/tests/testthat/test_jvm_api.R @@ -17,7 +17,7 @@ context("JVM API") -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) test_that("Create and call methods on object", { jarr <- sparkR.newJObject("java.util.ArrayList") diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R deleted file mode 100644 index db98d0e45547e..0000000000000 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ /dev/null @@ -1,942 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -library(testthat) - -context("MLlib functions") - -# Tests for MLlib functions in SparkR -sparkSession <- sparkR.session(enableHiveSupport = FALSE) - -absoluteSparkPath <- function(x) { - sparkHome <- sparkR.conf("spark.home") - file.path(sparkHome, x) -} - -test_that("formula of spark.glm", { - training <- suppressWarnings(createDataFrame(iris)) - # directly calling the spark API - # dot minus and intercept vs native glm - model <- spark.glm(training, Sepal_Width ~ . - Species + 0) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # feature interaction vs native glm - model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # glm should work with long formula - training <- suppressWarnings(createDataFrame(iris)) - training$LongLongLongLongLongName <- training$Sepal_Width - training$VeryLongLongLongLonLongName <- training$Sepal_Length - training$AnotherLongLongLongLongName <- training$Species - model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName + - AnotherLongLongLongLongName) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("spark.glm and predict", { - training <- suppressWarnings(createDataFrame(iris)) - # gaussian family - model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # poisson family - model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, - family = poisson(link = identity)) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, - data = iris, family = poisson(link = identity)), iris)) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) - -test_that("spark.glm summary", { - # gaussian family - training <- suppressWarnings(createDataFrame(iris)) - stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, 
rStats$aic) - - out <- capture.output(print(stats)) - expect_match(out[2], "Deviance Residuals:") - expect_true(any(grepl("AIC: 59.22", out))) - - # binomial family - df <- suppressWarnings(createDataFrame(iris)) - training <- df[df$Species %in% c("versicolor", "virginica"), ] - stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, - family = binomial(link = "logit"))) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit"))) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test spark.glm works with weighted dataset - a1 <- c(0, 1, 2, 3) - a2 <- c(5, 2, 1, 3) - w <- c(1, 2, 3, 4) - b <- c(1, 0, 1, 0) - data <- as.data.frame(cbind(a1, a2, w, b)) - df <- suppressWarnings(createDataFrame(data)) - - stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) - rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-3)) - expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test summary works on base GLM models - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) - - # Test spark.glm works with regularization parameter - data <- as.data.frame(cbind(a1, a2, b)) - df <- suppressWarnings(createDataFrame(data)) - regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) - expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result -}) - -test_that("spark.glm save/load", { - training <- suppressWarnings(createDataFrame(iris)) - m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) - s <- summary(m) - - modelPath <- tempfile(pattern = "spark-glm", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - - expect_equal(s$coefficients, s2$coefficients) - expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) - expect_equal(s$dispersion, s2$dispersion) - expect_equal(s$null.deviance, s2$null.deviance) - expect_equal(s$deviance, s2$deviance) - expect_equal(s$df.null, s2$df.null) - expect_equal(s$df.residual, s2$df.residual) - expect_equal(s$aic, s2$aic) - expect_equal(s$iter, s2$iter) - expect_true(!s$is.loaded) - expect_true(s2$is.loaded) - - unlink(modelPath) -}) - - - -test_that("formula of glm", { - training <- suppressWarnings(createDataFrame(iris)) - # dot minus and intercept vs native glm - 
model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # feature interaction vs native glm - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # glm should work with long formula - training <- suppressWarnings(createDataFrame(iris)) - training$LongLongLongLongLongName <- training$Sepal_Width - training$VeryLongLongLongLonLongName <- training$Sepal_Length - training$AnotherLongLongLongLongName <- training$Species - model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, - data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("glm and predict", { - training <- suppressWarnings(createDataFrame(iris)) - # gaussian family - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # poisson family - model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, - family = poisson(link = identity)) - prediction <- predict(model, training) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") - vals <- collect(select(prediction, "prediction")) - rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, - data = iris, family = poisson(link = identity)), iris)) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) - -test_that("glm summary", { - # gaussian family - training <- suppressWarnings(createDataFrame(iris)) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # binomial family - df <- suppressWarnings(createDataFrame(iris)) - training <- df[df$Species %in% c("versicolor", "virginica"), ] - stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, - family = binomial(link = "logit"))) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = 
"logit"))) - - coefs <- unlist(stats$coefficients) - rCoefs <- unlist(rStats$coefficients) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) - expect_equal(stats$dispersion, rStats$dispersion) - expect_equal(stats$null.deviance, rStats$null.deviance) - expect_equal(stats$deviance, rStats$deviance) - expect_equal(stats$df.null, rStats$df.null) - expect_equal(stats$df.residual, rStats$df.residual) - expect_equal(stats$aic, rStats$aic) - - # Test summary works on base GLM models - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) -}) - -test_that("glm save/load", { - training <- suppressWarnings(createDataFrame(iris)) - m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - s <- summary(m) - - modelPath <- tempfile(pattern = "glm", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - - expect_equal(s$coefficients, s2$coefficients) - expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) - expect_equal(s$dispersion, s2$dispersion) - expect_equal(s$null.deviance, s2$null.deviance) - expect_equal(s$deviance, s2$deviance) - expect_equal(s$df.null, s2$df.null) - expect_equal(s$df.residual, s2$df.residual) - expect_equal(s$aic, s2$aic) - expect_equal(s$iter, s2$iter) - expect_true(!s$is.loaded) - expect_true(s2$is.loaded) - - unlink(modelPath) -}) - -test_that("spark.kmeans", { - newIris <- iris - newIris$Species <- NULL - training <- suppressWarnings(createDataFrame(newIris)) - - take(training, 1) - - model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") - sample <- take(select(predict(model, training), "prediction"), 1) - expect_equal(typeof(sample$prediction), "integer") - expect_equal(sample$prediction, 1) - - # Test stats::kmeans is working - statsModel <- kmeans(x = newIris, centers = 2) - expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) - - # Test fitted works on KMeans - fitted.model <- fitted(model) - expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) - - # Test summary works on KMeans - summary.model <- summary(model) - cluster <- summary.model$cluster - expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) - expect_equal(summary.model$coefficients, summary2$coefficients) - expect_true(!summary.model$is.loaded) - expect_true(summary2$is.loaded) - - unlink(modelPath) -}) - -test_that("spark.mlp", { - df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), - source = "libsvm") - model <- spark.mlp(df, blockSize = 128, layers = c(4, 5, 4, 3), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) - - # Test summary method - summary <- summary(model) - expect_equal(summary$labelCount, 3) - expect_equal(summary$layers, c(4, 5, 4, 3)) - expect_equal(length(summary$weights), 64) - expect_equal(head(summary$weights, 5), 
list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), - tolerance = 1e-6) - - # Test predict method - mlpTestDF <- df - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 6), c(0, 1, 1, 1, 1, 1)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - summary2 <- summary(model2) - - expect_equal(summary2$labelCount, 3) - expect_equal(summary2$layers, c(4, 5, 4, 3)) - expect_equal(length(summary2$weights), 64) - - unlink(modelPath) - - # Test default parameter - model <- spark.mlp(df, layers = c(4, 5, 4, 3)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) - - # Test illegal parameter - expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") - expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") - - # Test random seed - # default seed - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) - # seed equals 10 - model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) - - # test initialWeights - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = - c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) - - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1)) - - model <- spark.mlp(df, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 0, 2, 1, 0, 0, 1)) -}) - -test_that("spark.naiveBayes", { - # R code to reproduce the result. - # We do not support instance weights yet. So we ignore the frequencies. 
- # - #' library(e1071) - #' t <- as.data.frame(Titanic) - #' t1 <- t[t$Freq > 0, -5] - #' m <- naiveBayes(Survived ~ ., data = t1) - #' m - #' predict(m, t1) - # - # -- output of 'm' - # - # A-priori probabilities: - # Y - # No Yes - # 0.4166667 0.5833333 - # - # Conditional probabilities: - # Class - # Y 1st 2nd 3rd Crew - # No 0.2000000 0.2000000 0.4000000 0.2000000 - # Yes 0.2857143 0.2857143 0.2857143 0.1428571 - # - # Sex - # Y Male Female - # No 0.5 0.5 - # Yes 0.5 0.5 - # - # Age - # Y Child Adult - # No 0.2000000 0.8000000 - # Yes 0.4285714 0.5714286 - # - # -- output of 'predict(m, t1)' - # - # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No - # - - t <- as.data.frame(Titanic) - t1 <- t[t$Freq > 0, -5] - df <- suppressWarnings(createDataFrame(t1)) - m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0) - s <- summary(m) - expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) - expect_equal(sum(s$apriori), 1) - expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) - p <- collect(select(predict(m, df), "prediction")) - expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", - "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", - "Yes", "Yes", "No", "No")) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") - write.ml(m, modelPath) - expect_error(write.ml(m, modelPath)) - write.ml(m, modelPath, overwrite = TRUE) - m2 <- read.ml(modelPath) - s2 <- summary(m2) - expect_equal(s$apriori, s2$apriori) - expect_equal(s$tables, s2$tables) - - unlink(modelPath) - - # Test e1071::naiveBayes - if (requireNamespace("e1071", quietly = TRUE)) { - expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) - expect_equal(as.character(predict(m, t1[1, ])), "Yes") - } - - # Test numeric response variable - t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) - t2 <- t1[-4] - df <- suppressWarnings(createDataFrame(t2)) - m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) - s <- summary(m) - expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) - expect_equal(sum(s$apriori), 1) - expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) -}) - -test_that("spark.survreg", { - # R code to reproduce the result. - # - #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), - #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) - #' library(survival) - #' model <- survreg(Surv(time, status) ~ x + sex, rData) - #' summary(model) - #' predict(model, data) - # - # -- output of 'summary(model)' - # - # Value Std. 
Error z p - # (Intercept) 1.315 0.270 4.88 1.07e-06 - # x -0.190 0.173 -1.10 2.72e-01 - # sex -0.253 0.329 -0.77 4.42e-01 - # Log(scale) -1.160 0.396 -2.93 3.41e-03 - # - # -- output of 'predict(model, data)' - # - # 1 2 3 4 5 6 7 - # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 - # - data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), - list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) - df <- createDataFrame(data, c("time", "status", "x", "sex")) - model <- spark.survreg(df, Surv(time, status) ~ x + sex) - stats <- summary(model) - coefs <- as.vector(stats$coefficients[, 1]) - rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) - expect_equal(coefs, rCoefs, tolerance = 1e-4) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "x", "sex", "Log(scale)"))) - p <- collect(select(predict(model, df), "prediction")) - expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, - 2.390146, 2.891269, 2.891269), tolerance = 1e-4) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - coefs2 <- as.vector(stats2$coefficients[, 1]) - expect_equal(coefs, coefs2) - expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) - - unlink(modelPath) - - # Test survival::survreg - if (requireNamespace("survival", quietly = TRUE)) { - rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), - x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) - expect_error( - model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), - NA) - expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) - } -}) - -test_that("spark.isotonicRegression", { - label <- c(7.0, 5.0, 3.0, 5.0, 1.0) - feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) - weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) - data <- as.data.frame(cbind(label, feature, weight)) - df <- suppressWarnings(createDataFrame(data)) - - model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, - weightCol = "weight") - # only allow one variable on the right hand side of the formula - expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE)) - result <- summary(model) - expect_equal(result$predictions, list(7, 5, 4, 4, 1)) - - # Test model prediction - predict_data <- list(list(-2.0), list(-1.0), list(0.5), - list(0.75), list(1.0), list(2.0), list(9.0)) - predict_df <- createDataFrame(predict_data, c("feature")) - predict_result <- collect(select(predict(model, predict_df), "prediction")) - expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-isotonicRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - expect_equal(result, summary(model2)) - - unlink(modelPath) -}) - -test_that("spark.logit", { - # test binary logistic regression - label <- c(1.0, 1.0, 1.0, 0.0, 0.0) - feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) - binary_data <- as.data.frame(cbind(label, feature)) - binary_df <- createDataFrame(binary_data) - - blr_model <- spark.logit(binary_df, label ~ feature, thresholds = 1.0) - blr_predict <- collect(select(predict(blr_model, binary_df), 
"prediction")) - expect_equal(blr_predict$prediction, c(0, 0, 0, 0, 0)) - blr_model1 <- spark.logit(binary_df, label ~ feature, thresholds = 0.0) - blr_predict1 <- collect(select(predict(blr_model1, binary_df), "prediction")) - expect_equal(blr_predict1$prediction, c(1, 1, 1, 1, 1)) - - # test summary of binary logistic regression - blr_summary <- summary(blr_model) - blr_fmeasure <- collect(select(blr_summary$fMeasureByThreshold, "threshold", "F-Measure")) - expect_equal(blr_fmeasure$threshold, c(0.8221347, 0.7884005, 0.6674709, 0.3785437, 0.3434487), - tolerance = 1e-4) - expect_equal(blr_fmeasure$"F-Measure", c(0.5000000, 0.8000000, 0.6666667, 0.8571429, 0.7500000), - tolerance = 1e-4) - blr_precision <- collect(select(blr_summary$precisionByThreshold, "threshold", "precision")) - expect_equal(blr_precision$precision, c(1.0000000, 1.0000000, 0.6666667, 0.7500000, 0.6000000), - tolerance = 1e-4) - blr_recall <- collect(select(blr_summary$recallByThreshold, "threshold", "recall")) - expect_equal(blr_recall$recall, c(0.3333333, 0.6666667, 0.6666667, 1.0000000, 1.0000000), - tolerance = 1e-4) - - # test model save and read - modelPath <- tempfile(pattern = "spark-logisticRegression", fileext = ".tmp") - write.ml(blr_model, modelPath) - expect_error(write.ml(blr_model, modelPath)) - write.ml(blr_model, modelPath, overwrite = TRUE) - blr_model2 <- read.ml(modelPath) - blr_predict2 <- collect(select(predict(blr_model2, binary_df), "prediction")) - expect_equal(blr_predict$prediction, blr_predict2$prediction) - expect_error(summary(blr_model2)) - unlink(modelPath) - - # test multinomial logistic regression - label <- c(0.0, 1.0, 2.0, 0.0, 0.0) - feature1 <- c(4.845940, 5.64480, 7.430381, 6.464263, 5.555667) - feature2 <- c(2.941319, 2.614812, 2.162451, 3.339474, 2.970987) - feature3 <- c(1.322733, 1.348044, 3.861237, 9.686976, 3.447130) - feature4 <- c(1.3246388, 0.5510444, 0.9225810, 1.2147881, 1.6020842) - data <- as.data.frame(cbind(label, feature1, feature2, feature3, feature4)) - df <- createDataFrame(data) - - model <- spark.logit(df, label ~., family = "multinomial", thresholds = c(0, 1, 1)) - predict1 <- collect(select(predict(model, df), "prediction")) - expect_equal(predict1$prediction, c(0, 0, 0, 0, 0)) - # Summary of multinomial logistic regression is not implemented yet - expect_error(summary(model)) -}) - -test_that("spark.gaussianMixture", { - # R code to reproduce the result. 
- # nolint start - #' library(mvtnorm) - #' set.seed(1) - #' a <- rmvnorm(7, c(0, 0)) - #' b <- rmvnorm(8, c(10, 10)) - #' data <- rbind(a, b) - #' model <- mvnormalmixEM(data, k = 2) - #' model$lambda - # - # [1] 0.4666667 0.5333333 - # - #' model$mu - # - # [1] 0.11731091 -0.06192351 - # [1] 10.363673 9.897081 - # - #' model$sigma - # - # [[1]] - # [,1] [,2] - # [1,] 0.62049934 0.06880802 - # [2,] 0.06880802 1.27431874 - # - # [[2]] - # [,1] [,2] - # [1,] 0.2961543 0.160783 - # [2,] 0.1607830 1.008878 - # nolint end - data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808), - list(0.3295078, -0.8204684), list(0.4874291, 0.7383247), - list(0.5757814, -0.3053884), list(1.5117812, 0.3898432), - list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664), - list(9.9838097, 10.9438362), list(10.8212212, 10.5939013), - list(10.9189774, 10.7821363), list(10.0745650, 8.0106483), - list(10.6198257, 9.9438713), list(9.8442045, 8.5292476), - list(9.5218499, 10.4179416)) - df <- createDataFrame(data, c("x1", "x2")) - model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2) - stats <- summary(model) - rLambda <- c(0.4666667, 0.5333333) - rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081) - rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874, - 0.2961543, 0.160783, 0.1607830, 1.008878) - expect_equal(stats$lambda, rLambda, tolerance = 1e-3) - expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3) - expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3) - p <- collect(select(predict(model, df), "prediction")) - expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$lambda, stats2$lambda) - expect_equal(unlist(stats$mu), unlist(stats2$mu)) - expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) - - unlink(modelPath) -}) - -test_that("spark.lda with libsvm", { - text <- read.df(absoluteSparkPath("data/mllib/sample_lda_libsvm_data.txt"), source = "libsvm") - model <- spark.lda(text, optimizer = "em") - - stats <- summary(model, 10) - isDistributed <- stats$isDistributed - logLikelihood <- stats$logLikelihood - logPerplexity <- stats$logPerplexity - vocabSize <- stats$vocabSize - topics <- stats$topicTopTerms - weights <- stats$topicTopTermsWeights - vocabulary <- stats$vocabulary - - expect_false(isDistributed) - expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) - expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) - expect_equal(vocabSize, 11) - expect_true(is.null(vocabulary)) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_false(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_equal(vocabulary, stats2$vocabulary) - - unlink(modelPath) -}) - -test_that("spark.lda with text input", { - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) - model <- spark.lda(text, optimizer = "online", features = "value") - - stats <- summary(model) - isDistributed <- stats$isDistributed - 
logLikelihood <- stats$logLikelihood - logPerplexity <- stats$logPerplexity - vocabSize <- stats$vocabSize - topics <- stats$topicTopTerms - weights <- stats$topicTopTermsWeights - vocabulary <- stats$vocabulary - - expect_false(isDistributed) - expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) - expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) - expect_equal(vocabSize, 10) - expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - - expect_false(stats2$isDistributed) - expect_equal(logLikelihood, stats2$logLikelihood) - expect_equal(logPerplexity, stats2$logPerplexity) - expect_equal(vocabSize, stats2$vocabSize) - expect_true(all.equal(vocabulary, stats2$vocabulary)) - - unlink(modelPath) -}) - -test_that("spark.posterior and spark.perplexity", { - text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) - model <- spark.lda(text, features = "value", k = 3) - - # Assert perplexities are equal - stats <- summary(model) - logPerplexity <- spark.perplexity(model, text) - expect_equal(logPerplexity, stats$logPerplexity) - - # Assert the sum of every topic distribution is equal to 1 - posterior <- spark.posterior(model, text) - local.posterior <- collect(posterior)$topicDistribution - expect_equal(length(local.posterior), sum(unlist(local.posterior))) -}) - -test_that("spark.als", { - data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), - list(2, 1, 1.0), list(2, 2, 5.0)) - df <- createDataFrame(data, c("user", "item", "score")) - model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", - rank = 10, maxIter = 5, seed = 0, reg = 0.1) - stats <- summary(model) - expect_equal(stats$rank, 10) - test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) - predictions <- collect(predict(model, test)) - - expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409), - tolerance = 1e-4) - - # Test model save/load - modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats2$rating, "score") - userFactors <- collect(stats$userFactors) - itemFactors <- collect(stats$itemFactors) - userFactors2 <- collect(stats2$userFactors) - itemFactors2 <- collect(stats2$itemFactors) - - orderUser <- order(userFactors$id) - orderUser2 <- order(userFactors2$id) - expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) - expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) - - orderItem <- order(itemFactors$id) - orderItem2 <- order(itemFactors2$id) - expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) - expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) - - unlink(modelPath) -}) - -test_that("spark.kstest", { - data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) - df <- createDataFrame(data) - testResult <- spark.kstest(df, "test", "norm") - stats <- summary(testResult) - - rStats <- ks.test(data$test, "pnorm", alternative = "two.sided") - - expect_equal(stats$p.value, rStats$p.value, 
tolerance = 1e-4) - expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") - - testResult <- spark.kstest(df, "test", "norm", -0.5) - stats <- summary(testResult) - - rStats <- ks.test(data$test, "pnorm", -0.5, 1, alternative = "two.sided") - - expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) - expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) - expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") -}) - -test_that("spark.randomForest Regression", { - data <- suppressWarnings(createDataFrame(longley)) - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 1) - - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, - 63.221, 63.639, 64.989, 63.761, - 66.019, 67.857, 68.169, 66.513, - 68.655, 69.564, 69.331, 70.551), - tolerance = 1e-4) - - stats <- summary(model) - expect_equal(stats$numTrees, 1) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - - model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, - numTrees = 20, seed = 123) - predictions <- collect(predict(model, data)) - expect_equal(predictions$prediction, c(60.379, 61.096, 60.636, 62.258, - 63.736, 64.296, 64.868, 64.300, - 66.709, 67.697, 67.966, 67.252, - 68.866, 69.593, 69.195, 69.658), - tolerance = 1e-4) - stats <- summary(model) - expect_equal(stats$numTrees, 20) - - modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$formula, stats2$formula) - expect_equal(stats$numFeatures, stats2$numFeatures) - expect_equal(stats$features, stats2$features) - expect_equal(stats$featureImportances, stats2$featureImportances) - expect_equal(stats$numTrees, stats2$numTrees) - expect_equal(stats$treeWeights, stats2$treeWeights) - - unlink(modelPath) -}) - -test_that("spark.randomForest Classification", { - data <- suppressWarnings(createDataFrame(iris)) - model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", - maxDepth = 5, maxBins = 16) - - stats <- summary(model) - expect_equal(stats$numFeatures, 2) - expect_equal(stats$numTrees, 20) - expect_error(capture.output(stats), NA) - expect_true(length(capture.output(stats)) > 6) - - modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") - write.ml(model, modelPath) - expect_error(write.ml(model, modelPath)) - write.ml(model, modelPath, overwrite = TRUE) - model2 <- read.ml(modelPath) - stats2 <- summary(model2) - expect_equal(stats$depth, stats2$depth) - expect_equal(stats$numNodes, stats2$numNodes) - expect_equal(stats$numClasses, stats2$numClasses) - - unlink(modelPath) -}) - -sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R new file mode 100644 index 0000000000000..f3eaeb381afc4 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -0,0 +1,385 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib classification algorithms, except for tree-based algorithms") + +# Tests for MLlib classification algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.svmLinear", { + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.svmLinear(training, Species ~ ., regParam = 0.01, maxIter = 10) + summary <- summary(model) + + # test summary coefficients return matrix type + expect_true(class(summary$coefficients) == "matrix") + expect_true(class(summary$coefficients[, 1]) == "numeric") + + coefs <- summary$coefficients[, "Estimate"] + expected_coefs <- c(-0.1563083, -0.460648, 0.2276626, 1.055085) + expect_true(all(abs(coefs - expected_coefs) < 0.1)) + expect_equal(summary$intercept, -0.06004978, tolerance = 1e-2) + + # Test prediction with string label + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("versicolor", "versicolor", "versicolor", "virginica", "virginica", + "virginica", "virginica", "virginica", "virginica", "virginica") + expect_equal(sort(as.list(take(select(prediction, "prediction"), 10))[[1]]), expected) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-svm-linear", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) + df <- createDataFrame(data) + model <- spark.svmLinear(df, label ~ feature, regParam = 0.1) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + +}) + +test_that("spark.logit", { + # R code to reproduce the result. 
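  # A note on the reference values: the glmnet call below uses alpha = 0 and
  # lambda = 0.5, i.e. an L2-penalized (ridge) multinomial fit, comparable to
  # spark.logit(..., regParam = 0.5) with its default elastic-net mixing of 0;
  # coefficients are then checked against these references with a 0.1 tolerance.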
+ # nolint start + #' library(glmnet) + #' iris.x = as.matrix(iris[, 1:4]) + #' iris.y = as.factor(as.character(iris[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $setosa + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.0981324 + # Sepal.Length -0.2909860 + # Sepal.Width 0.5510907 + # Petal.Length -0.1915217 + # Petal.Width -0.4211946 + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 1.520061e+00 + # Sepal.Length 2.524501e-02 + # Sepal.Width -5.310313e-01 + # Petal.Length 3.656543e-02 + # Petal.Width -3.144464e-05 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -2.61819385 + # Sepal.Length 0.26574097 + # Sepal.Width -0.02005932 + # Petal.Length 0.15495629 + # Petal.Width 0.42122607 + # nolint end + + # Test multinomial logistic regression againt three classes + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.logit(df, Species ~ ., regParam = 0.5) + summary <- summary(model) + + # test summary coefficients return matrix type + expect_true(class(summary$coefficients) == "matrix") + expect_true(class(summary$coefficients[, 1]) == "numeric") + + versicolorCoefsR <- c(1.52, 0.03, -0.53, 0.04, 0.00) + virginicaCoefsR <- c(-2.62, 0.27, -0.02, 0.16, 0.42) + setosaCoefsR <- c(1.10, -0.29, 0.55, -0.19, -0.42) + versicolorCoefs <- summary$coefficients[, "versicolor"] + virginicaCoefs <- summary$coefficients[, "virginica"] + setosaCoefs <- summary$coefficients[, "setosa"] + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + expect_true(all(abs(setosaCoefs - setosaCoefs) < 0.1)) + + # Test model save and load + modelPath <- tempfile(pattern = "spark-logit", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + coefs <- summary(model)$coefficients + coefs2 <- summary(model2)$coefficients + expect_equal(coefs, coefs2) + unlink(modelPath) + + # R code to reproduce the result. 
+ # nolint start + #' library(glmnet) + #' iris2 <- iris[iris$Species %in% c("versicolor", "virginica"), ] + #' iris.x = as.matrix(iris2[, 1:4]) + #' iris.y = as.factor(as.character(iris2[, 5])) + #' logit = glmnet(iris.x, iris.y, family="multinomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # $versicolor + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # 3.93844796 + # Sepal.Length -0.13538675 + # Sepal.Width -0.02386443 + # Petal.Length -0.35076451 + # Petal.Width -0.77971954 + # + # $virginica + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # -3.93844796 + # Sepal.Length 0.13538675 + # Sepal.Width 0.02386443 + # Petal.Length 0.35076451 + # Petal.Width 0.77971954 + # + #' logit = glmnet(iris.x, iris.y, family="binomial", alpha=0, lambda=0.5) + #' coef(logit) + # + # 5 x 1 sparse Matrix of class "dgCMatrix" + # s0 + # (Intercept) -6.0824412 + # Sepal.Length 0.2458260 + # Sepal.Width 0.1642093 + # Petal.Length 0.4759487 + # Petal.Width 1.0383948 + # + # nolint end + + # Test multinomial logistic regression againt two classes + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + model <- spark.logit(training, Species ~ ., regParam = 0.5, family = "multinomial") + summary <- summary(model) + versicolorCoefsR <- c(3.94, -0.16, -0.02, -0.35, -0.78) + virginicaCoefsR <- c(-3.94, 0.16, -0.02, 0.35, 0.78) + versicolorCoefs <- summary$coefficients[, "versicolor"] + virginicaCoefs <- summary$coefficients[, "virginica"] + expect_true(all(abs(versicolorCoefsR - versicolorCoefs) < 0.1)) + expect_true(all(abs(virginicaCoefsR - virginicaCoefs) < 0.1)) + + # Test binomial logistic regression againt two classes + model <- spark.logit(training, Species ~ ., regParam = 0.5) + summary <- summary(model) + coefsR <- c(-6.08, 0.25, 0.16, 0.48, 1.04) + coefs <- summary$coefficients[, "Estimate"] + expect_true(all(abs(coefsR - coefs) < 0.1)) + + # Test prediction with string label + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "character") + expected <- c("versicolor", "versicolor", "virginica", "versicolor", "versicolor", + "versicolor", "versicolor", "versicolor", "versicolor", "versicolor") + expect_equal(as.list(take(select(prediction, "prediction"), 10))[[1]], expected) + + # Test prediction with numeric label + label <- c(0.0, 0.0, 0.0, 1.0, 1.0) + feature <- c(1.1419053, 0.9194079, -0.9498666, -1.1069903, 0.2809776) + data <- as.data.frame(cbind(label, feature)) + df <- createDataFrame(data) + model <- spark.logit(df, label ~ feature) + prediction <- collect(select(predict(model, df), "prediction")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + + # Test prediction with weightCol + weight <- c(2.0, 2.0, 2.0, 1.0, 1.0) + data2 <- as.data.frame(cbind(label, feature, weight)) + df2 <- createDataFrame(data2) + model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") + prediction2 <- collect(select(predict(model2, df2), "prediction")) + expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) +}) + +test_that("spark.mlp", { + df <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.mlp(df, label ~ features, blockSize = 128, layers = c(4, 5, 4, 3), + solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1) + + # Test summary method + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + 
expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 5, 4, 3)) + expect_equal(length(summary$weights), 64) + expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), + tolerance = 1e-6) + + # Test predict method + mlpTestDF <- df + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", "0.0", "0.0", "0.0")) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-mlp", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + + expect_equal(summary2$numOfInputs, 4) + expect_equal(summary2$numOfOutputs, 3) + expect_equal(summary2$layers, c(4, 5, 4, 3)) + expect_equal(length(summary2$weights), 64) + + unlink(modelPath) + + # Test default parameter + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + # Test illegal parameter + expect_error(spark.mlp(df, label ~ features, layers = NULL), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c()), + "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, label ~ features, layers = c(3)), + "layers must be a integer vector with length > 1.") + + # Test random seed + # default seed + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + # seed equals 10 + model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + # test initialWeights + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = + c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) + + # Test formula works well + df <- suppressWarnings(createDataFrame(iris)) + model <- spark.mlp(df, Species ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, + layers = c(4, 3)) + summary <- summary(model) + expect_equal(summary$numOfInputs, 4) + expect_equal(summary$numOfOutputs, 3) + expect_equal(summary$layers, c(4, 3)) + expect_equal(length(summary$weights), 15) +}) + +test_that("spark.naiveBayes", { + # R code to reproduce the result. + # We do not support instance weights yet. So we ignore the frequencies. 
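  # (Concretely, t1 <- t[t$Freq > 0, -5] below keeps only the Titanic
  # combinations that actually occur and drops the Freq column, so every row is
  # treated as a single unweighted observation.)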
+ # + #' library(e1071) + #' t <- as.data.frame(Titanic) + #' t1 <- t[t$Freq > 0, -5] + #' m <- naiveBayes(Survived ~ ., data = t1) + #' m + #' predict(m, t1) + # + # -- output of 'm' + # + # A-priori probabilities: + # Y + # No Yes + # 0.4166667 0.5833333 + # + # Conditional probabilities: + # Class + # Y 1st 2nd 3rd Crew + # No 0.2000000 0.2000000 0.4000000 0.2000000 + # Yes 0.2857143 0.2857143 0.2857143 0.1428571 + # + # Sex + # Y Male Female + # No 0.5 0.5 + # Yes 0.5 0.5 + # + # Age + # Y Child Adult + # No 0.2000000 0.8000000 + # Yes 0.4285714 0.5714286 + # + # -- output of 'predict(m, t1)' + # + # Yes Yes Yes Yes No No Yes Yes No No Yes Yes Yes Yes Yes Yes Yes Yes No No Yes Yes No No + # + + t <- as.data.frame(Titanic) + t1 <- t[t$Freq > 0, -5] + df <- suppressWarnings(createDataFrame(t1)) + m <- spark.naiveBayes(df, Survived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, "Yes"]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables["Yes", "Age_Adult"]), 0.5714286, tolerance = 1e-6) + p <- collect(select(predict(m, df), "prediction")) + expect_equal(p$prediction, c("Yes", "Yes", "Yes", "Yes", "No", "No", "Yes", "Yes", "No", "No", + "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "Yes", "No", "No", + "Yes", "Yes", "No", "No")) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-naiveBayes", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + expect_equal(s$apriori, s2$apriori) + expect_equal(s$tables, s2$tables) + + unlink(modelPath) + + # Test e1071::naiveBayes + if (requireNamespace("e1071", quietly = TRUE)) { + expect_error(m <- e1071::naiveBayes(Survived ~ ., data = t1), NA) + expect_equal(as.character(predict(m, t1[1, ])), "Yes") + } + + # Test numeric response variable + t1$NumericSurvived <- ifelse(t1$Survived == "No", 0, 1) + t2 <- t1[-4] + df <- suppressWarnings(createDataFrame(t2)) + m <- spark.naiveBayes(df, NumericSurvived ~ ., smoothing = 0.0) + s <- summary(m) + expect_equal(as.double(s$apriori[1, 1]), 0.5833333, tolerance = 1e-6) + expect_equal(sum(s$apriori), 1) + expect_equal(as.double(s$tables[1, "Age_Adult"]), 0.5714286, tolerance = 1e-6) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R new file mode 100644 index 0000000000000..df8e5968b27f4 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -0,0 +1,318 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+library(testthat)
+
+context("MLlib clustering algorithms")
+
+# Tests for MLlib clustering algorithms in SparkR
+sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE)
+
+absoluteSparkPath <- function(x) {
+  sparkHome <- sparkR.conf("spark.home")
+  file.path(sparkHome, x)
+}
+
+test_that("spark.bisectingKmeans", {
+  newIris <- iris
+  newIris$Species <- NULL
+  training <- suppressWarnings(createDataFrame(newIris))
+
+  take(training, 1)
+
+  model <- spark.bisectingKmeans(data = training, ~ .)
+  sample <- take(select(predict(model, training), "prediction"), 1)
+  expect_equal(typeof(sample$prediction), "integer")
+  expect_equal(sample$prediction, 1)
+
+  # Test fitted works on Bisecting KMeans
+  fitted.model <- fitted(model)
+  expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction),
+               c(0, 1, 2, 3))
+
+  # Test summary works on Bisecting KMeans
+  summary.model <- summary(model)
+  cluster <- summary.model$cluster
+  k <- summary.model$k
+  expect_equal(k, 4)
+  expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction),
+               c(0, 1, 2, 3))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-bisectingkmeans", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  summary2 <- summary(model2)
+  expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size)))
+  expect_equal(summary.model$coefficients, summary2$coefficients)
+  expect_true(!summary.model$is.loaded)
+  expect_true(summary2$is.loaded)
+
+  unlink(modelPath)
+})
+
+test_that("spark.gaussianMixture", {
+  # R code to reproduce the result (mvnormalmixEM comes from the mixtools package).
+  # nolint start
+  #' library(mvtnorm)
+  #' library(mixtools)
+  #' set.seed(1)
+  #' a <- rmvnorm(7, c(0, 0))
+  #' b <- rmvnorm(8, c(10, 10))
+  #' data <- rbind(a, b)
+  #' model <- mvnormalmixEM(data, k = 2)
+  #' model$lambda
+  #
+  #  [1] 0.4666667 0.5333333
+  #
+  #' model$mu
+  #
+  #  [1] 0.11731091 -0.06192351
+  #  [1] 10.363673   9.897081
+  #
+  #' model$sigma
+  #
+  #  [[1]]
+  #            [,1]       [,2]
+  # [1,] 0.62049934 0.06880802
+  # [2,] 0.06880802 1.27431874
+  #
+  #  [[2]]
+  #           [,1]     [,2]
+  # [1,] 0.2961543 0.160783
+  # [2,] 0.1607830 1.008878
+  #
+  #' model$loglik
+  #
+  #  [1] -46.89499
+  # nolint end
+  data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
+               list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
+               list(0.5757814, -0.3053884), list(1.5117812, 0.3898432),
+               list(-0.6212406, -2.2146999), list(11.1249309, 9.9550664),
+               list(9.9838097, 10.9438362), list(10.8212212, 10.5939013),
+               list(10.9189774, 10.7821363), list(10.0745650, 8.0106483),
+               list(10.6198257, 9.9438713), list(9.8442045, 8.5292476),
+               list(9.5218499, 10.4179416))
+  df <- createDataFrame(data, c("x1", "x2"))
+  model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+  stats <- summary(model)
+  rLambda <- c(0.4666667, 0.5333333)
+  rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
+  rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
+              0.2961543, 0.160783, 0.1607830, 1.008878)
+  rLoglik <- -46.89499
+  expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
+  expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+  expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+  expect_equal(unlist(stats$loglik), rLoglik, tolerance = 1e-3)
+  p <- collect(select(predict(model, df), "prediction"))
+  expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern =
"spark-gaussianMixture", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$lambda, stats2$lambda) + expect_equal(unlist(stats$mu), unlist(stats2$mu)) + expect_equal(unlist(stats$sigma), unlist(stats2$sigma)) + expect_equal(unlist(stats$loglik), unlist(stats2$loglik)) + + unlink(modelPath) +}) + +test_that("spark.kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(newIris)) + + take(training, 1) + + model <- spark.kmeans(data = training, ~ ., k = 2, maxIter = 10, initMode = "random") + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(sort(unique(statsModel$cluster)), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + k <- summary.model$k + expect_equal(k, 2) + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) + + # test summary coefficients return matrix type + expect_true(class(summary.model$coefficients) == "matrix") + expect_true(class(summary.model$coefficients[1, ]) == "numeric") + + # Test model save/load + modelPath <- tempfile(pattern = "spark-kmeans", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + summary2 <- summary(model2) + expect_equal(sort(unlist(summary.model$size)), sort(unlist(summary2$size))) + expect_equal(summary.model$coefficients, summary2$coefficients) + expect_true(!summary.model$is.loaded) + expect_true(summary2$is.loaded) + + unlink(modelPath) + + # Test Kmeans on dataset that is sensitive to seed value + col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0) + cols <- as.data.frame(cbind(col1, col2, col3)) + df <- createDataFrame(cols) + + model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 1, tol = 1E-5) + model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10, + initMode = "random", seed = 22222, tol = 1E-5) + + summary.model1 <- summary(model1) + summary.model2 <- summary(model2) + cluster1 <- summary.model1$cluster + cluster2 <- summary.model2$cluster + clusterSize1 <- summary.model1$clusterSize + clusterSize2 <- summary.model2$clusterSize + + # The predicted clusters are different + expect_equal(sort(collect(distinct(select(cluster1, "prediction")))$prediction), + c(0, 1, 2, 3)) + expect_equal(sort(collect(distinct(select(cluster2, "prediction")))$prediction), + c(0, 1, 2)) + expect_equal(clusterSize1, 4) + expect_equal(clusterSize2, 3) +}) + +test_that("spark.lda with libsvm", { + text <- read.df(absoluteSparkPath("data/mllib/sample_lda_libsvm_data.txt"), source = "libsvm") + model <- spark.lda(text, optimizer = "em") + + stats <- summary(model, 10) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- 
stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + trainingLogLikelihood <- stats$trainingLogLikelihood + logPrior <- stats$logPrior + + expect_true(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 11) + expect_true(is.null(vocabulary)) + expect_true(trainingLogLikelihood <= 0 & !is.na(trainingLogLikelihood)) + expect_true(logPrior <= 0 & !is.na(logPrior)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_true(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_equal(vocabulary, stats2$vocabulary) + expect_equal(trainingLogLikelihood, stats2$trainingLogLikelihood) + expect_equal(logPrior, stats2$logPrior) + + unlink(modelPath) +}) + +test_that("spark.lda with text input", { + skip_on_cran() + + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) + model <- spark.lda(text, optimizer = "online", features = "value") + + stats <- summary(model) + isDistributed <- stats$isDistributed + logLikelihood <- stats$logLikelihood + logPerplexity <- stats$logPerplexity + vocabSize <- stats$vocabSize + topics <- stats$topicTopTerms + weights <- stats$topicTopTermsWeights + vocabulary <- stats$vocabulary + trainingLogLikelihood <- stats$trainingLogLikelihood + logPrior <- stats$logPrior + + expect_false(isDistributed) + expect_true(logLikelihood <= 0 & is.finite(logLikelihood)) + expect_true(logPerplexity >= 0 & is.finite(logPerplexity)) + expect_equal(vocabSize, 10) + expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9"))) + expect_true(is.na(trainingLogLikelihood)) + expect_true(is.na(logPrior)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + + expect_false(stats2$isDistributed) + expect_equal(logLikelihood, stats2$logLikelihood) + expect_equal(logPerplexity, stats2$logPerplexity) + expect_equal(vocabSize, stats2$vocabSize) + expect_true(all.equal(vocabulary, stats2$vocabulary)) + expect_true(is.na(stats2$trainingLogLikelihood)) + expect_true(is.na(stats2$logPrior)) + + unlink(modelPath) +}) + +test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) + model <- spark.lda(text, features = "value", k = 3) + + # Assert perplexities are equal + stats <- summary(model) + logPerplexity <- spark.perplexity(model, text) + expect_equal(logPerplexity, stats$logPerplexity) + + # Assert the sum of every topic distribution is equal to 1 + posterior <- spark.posterior(model, text) + local.posterior <- collect(posterior)$topicDistribution + expect_equal(length(local.posterior), sum(unlist(local.posterior))) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R new file mode 100644 index 0000000000000..1fa5375f9da31 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R @@ -0,0 +1,83 @@ +# +# 
Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib frequent pattern mining") + +# Tests for MLlib frequent pattern mining algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +test_that("spark.fpGrowth", { + data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,2", + "1,2,3", + "1,3" + ))), "split(items, ',') as items") + + model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1) + + itemsets <- collect(spark.freqItemsets(model)) + + expected_itemsets <- data.frame( + items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), list("1"))), + freq = c(2, 2, 3, 3, 4) + ) + + expect_equivalent(expected_itemsets, itemsets) + + expected_association_rules <- data.frame( + antecedent = I(list(list("2"), list("3"))), + consequent = I(list(list("1"), list("1"))), + confidence = c(1, 1) + ) + + expect_equivalent(expected_association_rules, collect(spark.associationRules(model))) + + new_data <- selectExpr(createDataFrame(data.frame(items = c( + "1,2", + "1,3", + "2,3" + ))), "split(items, ',') as items") + + expected_predictions <- data.frame( + items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))), + prediction = I(list(list(), list(), list("1"))) + ) + + expect_equivalent(expected_predictions, collect(predict(model, new_data))) + + modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp") + write.ml(model, modelPath, overwrite = TRUE) + loaded_model <- read.ml(modelPath) + + expect_equivalent( + itemsets, + collect(spark.freqItemsets(loaded_model))) + + unlink(modelPath) + + model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8) + expect_equal( + count(spark.freqItemsets(model_without_numpartitions)), + count(spark.freqItemsets(model)) + ) + +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_recommendation.R b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R new file mode 100644 index 0000000000000..e3e2b15c71361 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_recommendation.R @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib recommendation algorithms") + +# Tests for MLlib recommendation algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +test_that("spark.als", { + data <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), + list(2, 1, 1.0), list(2, 2, 5.0)) + df <- createDataFrame(data, c("user", "item", "score")) + model <- spark.als(df, ratingCol = "score", userCol = "user", itemCol = "item", + rank = 10, maxIter = 5, seed = 0, regParam = 0.1) + stats <- summary(model) + expect_equal(stats$rank, 10) + test <- createDataFrame(list(list(0, 2), list(1, 0), list(2, 0)), c("user", "item")) + predictions <- collect(predict(model, test)) + + expect_equal(predictions$prediction, c(-0.1380762, 2.6258414, -1.5018409), + tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-als", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats2$rating, "score") + userFactors <- collect(stats$userFactors) + itemFactors <- collect(stats$itemFactors) + userFactors2 <- collect(stats2$userFactors) + itemFactors2 <- collect(stats2$itemFactors) + + orderUser <- order(userFactors$id) + orderUser2 <- order(userFactors2$id) + expect_equal(userFactors$id[orderUser], userFactors2$id[orderUser2]) + expect_equal(userFactors$features[orderUser], userFactors2$features[orderUser2]) + + orderItem <- order(itemFactors$id) + orderItem2 <- order(itemFactors2$id) + expect_equal(itemFactors$id[orderItem], itemFactors2$id[orderItem2]) + expect_equal(itemFactors$features[orderItem], itemFactors2$features[orderItem2]) + + unlink(modelPath) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R new file mode 100644 index 0000000000000..44c98be906d81 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -0,0 +1,476 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib regression algorithms, except for tree-based algorithms") + +# Tests for MLlib regression algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +test_that("formula of spark.glm", { + skip_on_cran() + + training <- suppressWarnings(createDataFrame(iris)) + # directly calling the spark API + # dot minus and intercept vs native glm + model <- spark.glm(training, Sepal_Width ~ . - Species + 0) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # feature interaction vs native glm + model <- spark.glm(training, Sepal_Width ~ Species:Sepal_Length) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # glm should work with long formula + training <- suppressWarnings(createDataFrame(iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- spark.glm(training, LongLongLongLongLongName ~ VeryLongLongLongLonLongName + + AnotherLongLongLongLongName) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("spark.glm and predict", { + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # poisson family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + + # tweedie family + model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 
0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("spark.glm summary", { + # gaussian family + training <- suppressWarnings(createDataFrame(iris)) + stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species)) + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + + # test summary coefficients return matrix type + expect_true(class(stats$coefficients) == "matrix") + expect_true(class(stats$coefficients[, 1]) == "numeric") + + coefs <- stats$coefficients + rCoefs <- rStats$coefficients + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + out <- capture.output(print(stats)) + expect_match(out[2], "Deviance Residuals:") + expect_true(any(grepl("AIC: 59.22", out))) + + # binomial family + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width, + family = binomial(link = "logit"))) + + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit"))) + + coefs <- stats$coefficients + rCoefs <- rStats$coefficients + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test spark.glm works with weighted dataset + a1 <- c(0, 1, 2, 3) + a2 <- c(5, 2, 1, 3) + w <- c(1, 2, 3, 4) + b <- c(1, 0, 1, 0) + data <- as.data.frame(cbind(a1, a2, w, b)) + df <- createDataFrame(data) + + stats <- summary(spark.glm(df, b ~ a1 + a2, family = "binomial", weightCol = "w")) + rStats <- summary(glm(b ~ a1 + a2, family = "binomial", data = data, weights = w)) + + coefs <- stats$coefficients + rCoefs <- rStats$coefficients + expect_true(all(abs(rCoefs - coefs) < 1e-3)) + expect_true(all(rownames(stats$coefficients) == c("(Intercept)", "a1", "a2"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test summary works on base GLM models + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + + # Test spark.glm works with regularization parameter + data <- as.data.frame(cbind(a1, a2, 
b)) + df <- suppressWarnings(createDataFrame(data)) + regStats <- summary(spark.glm(df, b ~ a1 + a2, regParam = 1.0)) + expect_equal(regStats$aic, 13.32836, tolerance = 1e-4) # 13.32836 is from summary() result + + # Test spark.glm works on collinear data + A <- matrix(c(1, 2, 3, 4, 2, 4, 6, 8), 4, 2) + b <- c(1, 2, 3, 4) + data <- as.data.frame(cbind(A, b)) + df <- createDataFrame(data) + stats <- summary(spark.glm(df, b ~ . - 1)) + coefs <- stats$coefficients + expect_true(all(abs(c(0.5, 0.25) - coefs) < 1e-4)) +}) + +test_that("spark.glm save/load", { + skip_on_cran() + + training <- suppressWarnings(createDataFrame(iris)) + m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) + s <- summary(m) + + modelPath <- tempfile(pattern = "spark-glm", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 <- summary(m2) + + expect_equal(s$coefficients, s2$coefficients) + expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) + expect_equal(s$dispersion, s2$dispersion) + expect_equal(s$null.deviance, s2$null.deviance) + expect_equal(s$deviance, s2$deviance) + expect_equal(s$df.null, s2$df.null) + expect_equal(s$df.residual, s2$df.residual) + expect_equal(s$aic, s2$aic) + expect_equal(s$iter, s2$iter) + expect_true(!s$is.loaded) + expect_true(s2$is.loaded) + + unlink(modelPath) +}) + +test_that("formula of glm", { + skip_on_cran() + + training <- suppressWarnings(createDataFrame(iris)) + # dot minus and intercept vs native glm + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . - Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # feature interaction vs native glm + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # glm should work with long formula + training <- suppressWarnings(createDataFrame(iris)) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + +test_that("glm and predict", { + skip_on_cran() + + training <- suppressWarnings(createDataFrame(iris)) + # gaussian family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # poisson family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- 
collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + + # tweedie family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = "tweedie", var.power = 1.2, link.power = 0.0) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + + # manual calculation of the R predicted values to avoid dependence on statmod + #' library(statmod) + #' rModel <- glm(Sepal.Width ~ Sepal.Length + Species, data = iris, + #' family = tweedie(var.power = 1.2, link.power = 0.0)) + #' print(coef(rModel)) + + rCoef <- c(0.6455409, 0.1169143, -0.3224752, -0.3282174) + rVals <- exp(as.numeric(model.matrix(Sepal.Width ~ Sepal.Length + Species, + data = iris) %*% rCoef)) + expect_true(all(abs(rVals - vals) < 1e-5), rVals - vals) + + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) +}) + +test_that("glm summary", { + skip_on_cran() + + # gaussian family + training <- suppressWarnings(createDataFrame(iris)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + + coefs <- stats$coefficients + rCoefs <- rStats$coefficients + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # binomial family + df <- suppressWarnings(createDataFrame(iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = binomial(link = "logit"))) + + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit"))) + + coefs <- stats$coefficients + rCoefs <- rStats$coefficients + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test summary works on base GLM models + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) + +test_that("glm save/load", { + skip_on_cran() + + training <- suppressWarnings(createDataFrame(iris)) + m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) + s <- summary(m) + + modelPath <- tempfile(pattern = "glm", fileext = ".tmp") + write.ml(m, modelPath) + expect_error(write.ml(m, modelPath)) + write.ml(m, modelPath, overwrite = TRUE) + m2 <- read.ml(modelPath) + s2 
<- summary(m2) + + expect_equal(s$coefficients, s2$coefficients) + expect_equal(rownames(s$coefficients), rownames(s2$coefficients)) + expect_equal(s$dispersion, s2$dispersion) + expect_equal(s$null.deviance, s2$null.deviance) + expect_equal(s$deviance, s2$deviance) + expect_equal(s$df.null, s2$df.null) + expect_equal(s$df.residual, s2$df.residual) + expect_equal(s$aic, s2$aic) + expect_equal(s$iter, s2$iter) + expect_true(!s$is.loaded) + expect_true(s2$is.loaded) + + unlink(modelPath) +}) + +test_that("spark.isoreg", { + label <- c(7.0, 5.0, 3.0, 5.0, 1.0) + feature <- c(0.0, 1.0, 2.0, 3.0, 4.0) + weight <- c(1.0, 1.0, 1.0, 1.0, 1.0) + data <- as.data.frame(cbind(label, feature, weight)) + df <- createDataFrame(data) + + model <- spark.isoreg(df, label ~ feature, isotonic = FALSE, + weightCol = "weight") + # only allow one variable on the right hand side of the formula + expect_error(model2 <- spark.isoreg(df, ~., isotonic = FALSE)) + result <- summary(model) + expect_equal(result$predictions, list(7, 5, 4, 4, 1)) + + # Test model prediction + predict_data <- list(list(-2.0), list(-1.0), list(0.5), + list(0.75), list(1.0), list(2.0), list(9.0)) + predict_df <- createDataFrame(predict_data, c("feature")) + predict_result <- collect(select(predict(model, predict_df), "prediction")) + expect_equal(predict_result$prediction, c(7.0, 7.0, 6.0, 5.5, 5.0, 4.0, 1.0)) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-isoreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + expect_equal(result, summary(model2)) + + unlink(modelPath) +}) + +test_that("spark.survreg", { + # R code to reproduce the result. + # + #' rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + #' x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + #' library(survival) + #' model <- survreg(Surv(time, status) ~ x + sex, rData) + #' summary(model) + #' predict(model, data) + # + # -- output of 'summary(model)' + # + # Value Std. 
Error z p + # (Intercept) 1.315 0.270 4.88 1.07e-06 + # x -0.190 0.173 -1.10 2.72e-01 + # sex -0.253 0.329 -0.77 4.42e-01 + # Log(scale) -1.160 0.396 -2.93 3.41e-03 + # + # -- output of 'predict(model, data)' + # + # 1 2 3 4 5 6 7 + # 3.724591 2.545368 3.079035 3.079035 2.390146 2.891269 2.891269 + # + data <- list(list(4, 1, 0, 0), list(3, 1, 2, 0), list(1, 1, 1, 0), + list(1, 0, 1, 0), list(2, 1, 1, 1), list(2, 1, 0, 1), list(3, 0, 0, 1)) + df <- createDataFrame(data, c("time", "status", "x", "sex")) + model <- spark.survreg(df, Surv(time, status) ~ x + sex) + stats <- summary(model) + coefs <- as.vector(stats$coefficients[, 1]) + rCoefs <- c(1.3149571, -0.1903409, -0.2532618, -1.1599800) + expect_equal(coefs, rCoefs, tolerance = 1e-4) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "x", "sex", "Log(scale)"))) + p <- collect(select(predict(model, df), "prediction")) + expect_equal(p$prediction, c(3.724591, 2.545368, 3.079035, 3.079035, + 2.390146, 2.891269, 2.891269), tolerance = 1e-4) + + # Test model save/load + modelPath <- tempfile(pattern = "spark-survreg", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + coefs2 <- as.vector(stats2$coefficients[, 1]) + expect_equal(coefs, coefs2) + expect_equal(rownames(stats$coefficients), rownames(stats2$coefficients)) + + unlink(modelPath) + + # Test survival::survreg + if (requireNamespace("survival", quietly = TRUE)) { + rData <- list(time = c(4, 3, 1, 1, 2, 2, 3), status = c(1, 1, 1, 0, 1, 1, 0), + x = c(0, 2, 1, 1, 1, 0, 0), sex = c(0, 0, 0, 0, 1, 1, 1)) + expect_error( + model <- survival::survreg(formula = survival::Surv(time, status) ~ x + sex, data = rData), + NA) + expect_equal(predict(model, rData)[[1]], 3.724591, tolerance = 1e-4) + } +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_stat.R b/R/pkg/inst/tests/testthat/test_mllib_stat.R new file mode 100644 index 0000000000000..1600833a5d03a --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_stat.R @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib statistics algorithms") + +# Tests for MLlib statistics algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +test_that("spark.kstest", { + data <- data.frame(test = c(0.1, 0.15, 0.2, 0.3, 0.25, -1, -0.5)) + df <- createDataFrame(data) + testResult <- spark.kstest(df, "test", "norm") + stats <- summary(testResult) + + rStats <- ks.test(data$test, "pnorm", alternative = "two.sided") + + expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) + expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + testResult <- spark.kstest(df, "test", "norm", -0.5) + stats <- summary(testResult) + + rStats <- ks.test(data$test, "pnorm", -0.5, 1, alternative = "two.sided") + + expect_equal(stats$p.value, rStats$p.value, tolerance = 1e-4) + expect_equal(stats$statistic, unname(rStats$statistic), tolerance = 1e-4) + expect_match(capture.output(stats)[1], "Kolmogorov-Smirnov test summary:") + + # Test print.summary.KSTest + printStats <- capture.output(print.summary.KSTest(stats)) + expect_match(printStats[1], "Kolmogorov-Smirnov test summary:") + expect_match(printStats[5], + "Low presumption against null hypothesis: Sample follows theoretical distribution. ") +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_mllib_tree.R b/R/pkg/inst/tests/testthat/test_mllib_tree.R new file mode 100644 index 0000000000000..146bc2878e263 --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_mllib_tree.R @@ -0,0 +1,212 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib tree-based algorithms") + +# Tests for MLlib tree-based algorithms in SparkR +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +absoluteSparkPath <- function(x) { + sparkHome <- sparkR.conf("spark.home") + file.path(sparkHome, x) +} + +test_that("spark.gbt", { + # regression + data <- suppressWarnings(createDataFrame(longley)) + model <- spark.gbt(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_equal(stats$formula, "Employed ~ .") + expect_equal(stats$numFeatures, 6) + expect_equal(length(stats$treeWeights), 20) + + modelPath <- tempfile(pattern = "spark-gbtRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + # label must be binary - GBTClassifier currently only supports binary classification. + iris2 <- iris[iris$Species != "virginica", ] + data <- suppressWarnings(createDataFrame(iris2)) + model <- spark.gbt(data, Species ~ Petal_Length + Petal_Width, "classification") + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + predictions <- collect(predict(model, data))$prediction + # test string prediction values + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-gbtClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) + df <- suppressWarnings(createDataFrame(iris2)) + m <- spark.gbt(df, NumericSpecies ~ ., type = "classification") + s <- summary(m) + # test numeric prediction values + expect_equal(iris2$NumericSpecies, as.double(collect(predict(m, df))$prediction)) + expect_equal(s$numFeatures, 5) + expect_equal(s$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + # spark.gbt classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_binary_classification_data.txt"), + source = "libsvm") + model <- spark.gbt(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 692) +}) + +test_that("spark.randomForest", { + # regression + data <- 
suppressWarnings(createDataFrame(longley)) + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 1) + + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.323, 61.122, 60.171, 61.187, + 63.221, 63.639, 64.989, 63.761, + 66.019, 67.857, 68.169, 66.513, + 68.655, 69.564, 69.331, 70.551), + tolerance = 1e-4) + + stats <- summary(model) + expect_equal(stats$numTrees, 1) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + + model <- spark.randomForest(data, Employed ~ ., "regression", maxDepth = 5, maxBins = 16, + numTrees = 20, seed = 123) + predictions <- collect(predict(model, data)) + expect_equal(predictions$prediction, c(60.32820, 61.22315, 60.69025, 62.11070, + 63.53160, 64.05470, 65.12710, 64.30450, + 66.70910, 67.86125, 68.08700, 67.21865, + 68.89275, 69.53180, 69.39640, 69.68250), + tolerance = 1e-4) + stats <- summary(model) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + modelPath <- tempfile(pattern = "spark-randomForestRegression", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$formula, stats2$formula) + expect_equal(stats$numFeatures, stats2$numFeatures) + expect_equal(stats$features, stats2$features) + expect_equal(stats$featureImportances, stats2$featureImportances) + expect_equal(stats$numTrees, stats2$numTrees) + expect_equal(stats$maxDepth, stats2$maxDepth) + expect_equal(stats$treeWeights, stats2$treeWeights) + + unlink(modelPath) + + # classification + data <- suppressWarnings(createDataFrame(iris)) + model <- spark.randomForest(data, Species ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + expect_error(capture.output(stats), NA) + expect_true(length(capture.output(stats)) > 6) + # Test string prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("setosa", predictions)), 50) + expect_equal(length(grep("versicolor", predictions)), 50) + + modelPath <- tempfile(pattern = "spark-randomForestClassification", fileext = ".tmp") + write.ml(model, modelPath) + expect_error(write.ml(model, modelPath)) + write.ml(model, modelPath, overwrite = TRUE) + model2 <- read.ml(modelPath) + stats2 <- summary(model2) + expect_equal(stats$depth, stats2$depth) + expect_equal(stats$numNodes, stats2$numNodes) + expect_equal(stats$numClasses, stats2$numClasses) + + unlink(modelPath) + + # Test numeric response variable + labelToIndex <- function(species) { + switch(as.character(species), + setosa = 0.0, + versicolor = 1.0, + virginica = 2.0 + ) + } + iris$NumericSpecies <- lapply(iris$Species, labelToIndex) + data <- suppressWarnings(createDataFrame(iris[-5])) + model <- spark.randomForest(data, NumericSpecies ~ Petal_Length + Petal_Width, "classification", + maxDepth = 5, maxBins = 16) + stats <- summary(model) + expect_equal(stats$numFeatures, 2) + expect_equal(stats$numTrees, 20) + expect_equal(stats$maxDepth, 5) + + # Test numeric prediction values + predictions <- collect(predict(model, data))$prediction + expect_equal(length(grep("1.0", predictions)), 50) + expect_equal(length(grep("2.0", predictions)), 50) + + # 
spark.randomForest classification can work on libsvm data + data <- read.df(absoluteSparkPath("data/mllib/sample_multiclass_classification_data.txt"), + source = "libsvm") + model <- spark.randomForest(data, label ~ features, "classification") + expect_equal(summary(model)$numFeatures, 4) +}) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba4693..52d4c93ed9599 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -33,12 +33,14 @@ numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3)) strPairs <- list(list(strList, strList), list(strList, strList)) # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index a3d66c245a7d1..fb244e1d49e20 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -18,7 +18,7 @@ context("basic RDD functions") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { - expect_equal(getNumPartitions(rdd), 2) - expect_equal(getNumPartitions(intRdd), 2) + skip_on_cran() + + expect_equal(getNumPartitionsRDD(rdd), 2) + expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 
3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -143,8 +167,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp expect_false(rdd2@env$isCached) tempDir <- tempfile(pattern = "checkpoint") - setCheckpointDir(sc, tempDir) - checkpoint(rdd2) + setCheckpointDirSC(sc, tempDir) + checkpointRDD(rdd2) expect_true(rdd2@env$isCheckpointed) rdd2 <- lapply(rdd2, function(x) x) @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), 
list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,27 +351,31 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition r1 <- repartitionRDD(rdd, 2) - expect_equal(getNumPartitions(r1), 2L) + expect_equal(getNumPartitionsRDD(r1), 2L) count <- length(collectPartition(r1, 0L)) expect_true(count >= 8 && count <= 12) r2 <- repartitionRDD(rdd, 6) - expect_equal(getNumPartitions(r2), 6L) + expect_equal(getNumPartitionsRDD(r2), 6L) count <- length(collectPartition(r2, 0L)) expect_true(count >= 0 && count <= 4) # coalesce - r3 <- coalesce(rdd, 1) - expect_equal(getNumPartitions(r3), 1L) + r3 <- coalesceRDD(rdd, 1) + expect_equal(getNumPartitionsRDD(r3), 1L) count <- length(collectPartition(r3, 0L)) expect_equal(count, 20) }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,10 +441,12 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) - expected <- list(list("a", 0), list("b", 3), list("c", 1), - list("d", 4), list("e", 2)) + expected <- list(list("a", 0), list("b", 1), list("c", 4), + list("d", 2), list("e", 5)) expect_equal(actual, expected) rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L) @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) 
test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 +592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() 
on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1df..18320ea44b389 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -18,7 +18,7 @@ context("partitionBy, groupByKey, reduceByKey etc.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) # Data @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if 
(key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R new file mode 100644 index 0000000000000..a40981c188f7a --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +context("functions in sparkR.R") + +test_that("sparkCheckInstall", { + skip_on_cran() + + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- paste0(tempdir(), "/", "sparkHome") + dir.create(sparkHome) + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + unlink(sparkHome, recursive = TRUE) + + # "yarn-cluster, mesos-cluster" mode, SPARK_HOME was not set, + # and the SparkR job was submitted by "spark-submit" + sparkHome <- "" + master <- "" + deployMode <- "" + expect_true(is.null(sparkCheckInstall(sparkHome, master, deployMode))) + + # "yarn-client, mesos-client" mode, SPARK_HOME was not set + sparkHome <- "" + master <- "yarn-client" + deployMode <- "" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) + sparkHome <- "" + master <- "" + deployMode <- "client" + expect_error(sparkCheckInstall(sparkHome, master, deployMode)) +}) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 806019d7524ff..b633b78d5bb4d 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -60,7 +60,8 @@ unsetHiveContext <- function() { # Tests for SparkSQL functions in SparkR -sparkSession <- sparkR.session() +filesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkSession <- sparkR.session(master = sparkRTestMaster) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockLines <- c("{\"name\":\"Michael\"}", @@ -88,16 +89,33 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) +# For test map type and struct type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + 
"{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + +if (.Platform$OS.type == "windows") { + Sys.setenv(TZ = "GMT") +} + test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -132,7 +150,71 @@ test_that("structType and structField", { expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) +test_that("structField type strings", { + # positive cases + primitiveTypes <- list(byte = "ByteType", + integer = "IntegerType", + float = "FloatType", + double = "DoubleType", + string = "StringType", + binary = "BinaryType", + boolean = "BooleanType", + timestamp = "TimestampType", + date = "DateType", + tinyint = "ByteType", + smallint = "ShortType", + int = "IntegerType", + bigint = "LongType", + decimal = "DecimalType(10,0)") + + complexTypes <- list("map" = "MapType(StringType,IntegerType,true)", + "array" = "ArrayType(StringType,true)", + "struct" = "StructType(StructField(a,StringType,true))") + + typeList <- c(primitiveTypes, complexTypes) + typeStrings <- names(typeList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- typeList[[i]] + testField <- structField("_col", typeString) + expect_is(testField, "structField") + expect_true(testField$nullable()) + expect_equal(testField$dataType.toString(), expected) + } + + # negative cases + primitiveErrors <- list(Byte = "Byte", + INTEGER = "INTEGER", + numeric = "numeric", + character = "character", + raw = "raw", + logical = "logical", + short = "short", + varchar = "varchar", + long = "long", + char = "char") + + complexErrors <- list("map" = " integer", + "array" = "String", + "struct" = "string ", + "map " = "map ", + "array< string>" = " string", + "struct" = " string") + + errorList <- c(primitiveErrors, complexErrors) + typeStrings <- names(errorList) + + for (i in seq_along(typeStrings)){ + typeString <- typeStrings[i] + expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]]) + expect_error(structField("_col", typeString), expected) + } +}) + test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -196,23 +278,47 @@ test_that("create DataFrame from RDD", { expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float"))) expect_equal(as.list(collect(where(df, df$name == "John"))), list(name = "John", age = 19L, height = 176.5)) + expect_equal(getNumPartitions(df), 1) + + df <- as.DataFrame(cars, numPartitions = 2) + expect_equal(getNumPartitions(df), 2) + df <- createDataFrame(cars, numPartitions = 3) + expect_equal(getNumPartitions(df), 3) + # validate limit by num of rows + df <- createDataFrame(cars, numPartitions = 60) + expect_equal(getNumPartitions(df), 50) + # validate when 1 < (length(coll) / numSlices) << length(coll) + df <- createDataFrame(cars, numPartitions = 20) + expect_equal(getNumPartitions(df), 
20) + + df <- as.DataFrame(data.frame(0)) + expect_is(df, "SparkDataFrame") + df <- createDataFrame(list(list(1))) + expect_is(df, "SparkDataFrame") + df <- as.DataFrame(data.frame(0), numPartitions = 2) + # no data to partition, goes to 1 + expect_equal(getNumPartitions(df), 1) setHiveContext(sc) sql("CREATE TABLE people (name string, age double, height float)") df <- read.df(jsonPathNa, "json", schema) - invisible(insertInto(df, "people")) + insertInto(df, "people") expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) expect_equal(collect(sql("SELECT height from people WHERE name ='Bob'"))$height, c(176.5)) + sql("DROP TABLE people") unsetHiveContext() }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") - df <- suppressWarnings(createDataFrame(iris)) + df <- suppressWarnings(createDataFrame(iris, numPartitions = 3)) + expect_equal(getNumPartitions(df), 3) # Resetting the conf back to default value callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10)) @@ -268,6 +374,8 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -322,6 +430,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -433,6 +543,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -444,14 +556,9 @@ test_that("create DataFrame from a data.frame with complex types", { expect_equal(ldf$an_envir, collected$an_envir) }) -# For test map type and struct type in DataFrame -mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", - "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", - "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") -mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") -writeLines(mockLinesMapType, mapTypeJsonPath) - test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -539,6 +646,8 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -552,6 +661,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -566,20 +677,27 @@ test_that("jsonRDD() on a RDD with json string", { }) test_that("test tableNames and tables", { + count <- count(listTables()) + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") - expect_equal(length(tableNames()), 1) - tables 
<- tables() - expect_equal(count(tables), 1) + expect_equal(length(tableNames()), count + 1) + expect_equal(length(tableNames("default")), count + 1) + + tables <- listTables() + expect_equal(count(tables), count + 1) + expect_equal(count(tables()), count(tables)) + expect_true("tableName" %in% colnames(tables())) + expect_true(all(c("tableName", "database", "isTemporary") %in% colnames(tables()))) suppressWarnings(registerTempTable(df, "table2")) - tables <- tables() - expect_equal(count(tables), 2) + tables <- listTables() + expect_equal(count(tables), count + 2) suppressWarnings(dropTempTable("table1")) - dropTempView("table2") + expect_true(dropTempView("table2")) - tables <- tables() - expect_equal(count(tables), 0) + tables <- listTables() + expect_equal(count(tables), count + 0) }) test_that( @@ -589,7 +707,7 @@ test_that( newdf <- sql("SELECT * FROM table1 where name = 'Michael'") expect_is(newdf, "SparkDataFrame") expect_equal(count(newdf), 1) - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(df, "dfView") sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1")) @@ -600,16 +718,21 @@ test_that( expect_equal(ncol(sqlCast), 1) expect_equal(out[1], " x") expect_equal(out[2], "1 2") - dropTempView("dfView") + expect_true(dropTempView("dfView")) }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") uncacheTable("table1") clearCache() - dropTempView("table1") + expect_true(dropTempView("table1")) + + expect_error(uncacheTable("foo"), + "Error in uncacheTable : no such table - Table or view 'foo' not found in database 'default'") }) test_that("insertInto() on a registered table", { @@ -630,13 +753,13 @@ test_that("insertInto() on a registered table", { insertInto(dfParquet2, "table1") expect_equal(count(sql("select * from table1")), 5) expect_equal(first(sql("select * from table1 order by age"))$name, "Michael") - dropTempView("table1") + expect_true(dropTempView("table1")) createOrReplaceTempView(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) expect_equal(count(sql("select * from table1")), 2) expect_equal(first(sql("select * from table1 order by age"))$name, "Bob") - dropTempView("table1") + expect_true(dropTempView("table1")) unlink(jsonPath2) unlink(parquetPath2) @@ -650,10 +773,12 @@ test_that("tableToDF() returns a new DataFrame", { expect_equal(count(tabledf), 3) tabledf2 <- tableToDF("table1") expect_equal(count(tabledf2), 3) - dropTempView("table1") + expect_true(dropTempView("table1")) }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -661,6 +786,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -671,6 +798,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -700,10 +829,12 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { + skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) - 
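A short usage sketch of the behaviour these assertions rely on, assuming a local session: `listTables()` returns a SparkDataFrame that can be counted and collected, and `dropTempView()` now reports success with `TRUE`.

```R
df <- as.DataFrame(cars)
createOrReplaceTempView(df, "cars_view")
tbs <- collect(listTables())
"cars_view" %in% tbs$name      # TRUE
dropTempView("cars_view")      # TRUE
```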
saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) + saveAsObjectFile(coalesceRDD(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) expect_is(objectIn, "RDD") @@ -712,6 +843,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -780,6 +913,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -818,6 +953,17 @@ test_that("cache(), storageLevel(), persist(), and unpersist() on a DataFrame", expect_true(is.data.frame(collect(df))) }) +test_that("setCheckpointDir(), checkpoint() on a DataFrame", { + checkpointDir <- file.path(tempdir(), "cproot") + expect_true(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) + + setCheckpointDir(checkpointDir) + df <- read.json(jsonPath) + df <- checkpoint(df) + expect_is(df, "SparkDataFrame") + expect_false(length(list.files(path = checkpointDir, all.files = TRUE)) == 0) +}) + test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- read.json(jsonPath) testSchema <- schema(df) @@ -847,6 +993,14 @@ test_that("names() colnames() set the column names", { colnames(df) <- c("col3", "col4") expect_equal(names(df)[1], "col3") + expect_error(names(df) <- NULL, "Invalid column names.") + expect_error(names(df) <- c("sepal.length", "sepal_width"), + "Column names cannot contain the '.' symbol.") + expect_error(names(df) <- c(1, 2), "Invalid column names.") + expect_error(names(df) <- c("a"), + "Column names must have the same length as the number of columns in the dataset.") + expect_error(names(df) <- c("1", NA), "Column names cannot be NA.") + expect_error(colnames(df) <- c("sepal.length", "sepal_width"), "Column names cannot contain the '.' symbol.") expect_error(colnames(df) <- c(1, 2), "Invalid column names.") @@ -868,6 +1022,12 @@ test_that("names() colnames() set the column names", { expect_equal(names(z)[3], "c") names(z)[3] <- "c2" expect_equal(names(z)[3], "c2") + + # Test subset assignment + colnames(df)[1] <- "col5" + expect_equal(colnames(df)[1], "col5") + names(df)[2] <- "col6" + expect_equal(names(df)[2], "col6") }) test_that("head() and first() return the correct data", { @@ -985,6 +1145,18 @@ test_that("select operators", { expect_is(df[[2]], "Column") expect_is(df[["age"]], "Column") + expect_warning(df[[1:2]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[1:2]]), "Column") + expect_warning(df[[c("name", "age")]], + "Subset index has length > 1. Only the first index is used.") + expect_is(suppressWarnings(df[[c("name", "age")]]), "Column") + + expect_warning(df[[1:2]] <- df[[1]], + "Subset index has length > 1. Only the first index is used.") + expect_warning(df[[c("name", "age")]] <- df[[1]], + "Subset index has length > 1. 
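A sketch of the renaming rules the new assertions cover: whole-vector assignment must be character, NA-free, '.'-free and match the column count, while single positions can now be renamed with subset assignment.

```R
df <- as.DataFrame(cars)            # columns: speed, dist
colnames(df)[1] <- "speed_mph"      # subset assignment
names(df)                           # "speed_mph" "dist"
# names(df) <- c("a.b", "c")        # would error: '.' is not allowed in names
```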
Only the first index is used.") + expect_is(df[, 1, drop = F], "SparkDataFrame") expect_equal(columns(df[, 1, drop = F]), c("name")) expect_equal(columns(df[, "age", drop = F]), c("age")) @@ -999,6 +1171,37 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + df$age2 <- df[["age"]] * 3 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 3)), 2) + + df$age2 <- 21 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 21)), 3) + + df$age2 <- c(22) + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 22)), 3) + + expect_error(df$age3 <- c(22, NA), + "value must be a Column, literal value as atomic in length of 1, or NULL") + + df[["age2"]] <- 23 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 23)), 3) + + df[[3]] <- 24 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 24)), 3) + + df[[3]] <- df$age + expect_equal(count(where(df, df$age2 == df$age)), 2) + + df[["age2"]] <- df[["name"]] + expect_equal(count(where(df, df$age2 == df$name)), 3) + + expect_error(df[["age3"]] <- c(22, 23), + "value must be a Column, literal value as atomic in length of 1, or NULL") # Test parameter drop expect_equal(class(df[, 1]) == "SparkDataFrame", T) @@ -1027,6 +1230,16 @@ test_that("select with column", { expect_equal(columns(df4), c("name", "age")) expect_equal(count(df4), 3) + # Test select with alias + df5 <- alias(df, "table") + + expect_equal(columns(select(df5, column("table.name"))), "name") + expect_equal(columns(select(df5, "table.name")), "name") + + # Test that stats::alias is not masked + expect_is(alias(aov(yield ~ block + N * P * K, npk)), "listof") + + expect_error(select(df, c("name", "age"), "name"), "To select multiple columns, use a character vector or list for col") }) @@ -1117,7 +1330,16 @@ test_that("column calculation", { test_that("test HiveContext", { setHiveContext(sc) - df <- createExternalTable("json", jsonPath, "json") + + schema <- structType(structField("name", "string"), structField("age", "integer"), + structField("height", "float")) + createTable("people", source = "json", schema = schema) + df <- read.df(jsonPathNa, "json", schema) + insertInto(df, "people") + expect_equal(collect(sql("SELECT age from people WHERE name = 'Bob'"))$age, c(16)) + sql("DROP TABLE people") + + df <- createTable("json", jsonPath, "json") expect_is(df, "SparkDataFrame") expect_equal(count(df), 3) df2 <- sql("select * from json") @@ -1125,25 +1347,26 @@ test_that("test HiveContext", { expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "json2", "json", "append", path = jsonPath2)) + saveAsTable(df, "json2", "json", "append", path = jsonPath2) df3 <- sql("select * from json2") expect_is(df3, "SparkDataFrame") expect_equal(count(df3), 3) unlink(jsonPath2) hivetestDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "hivetestbl", path = hivetestDataPath)) + saveAsTable(df, "hivetestbl", path = hivetestDataPath) df4 <- sql("select * from hivetestbl") expect_is(df4, "SparkDataFrame") expect_equal(count(df4), 3) unlink(hivetestDataPath) parquetDataPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") - invisible(saveAsTable(df, "parquetest", "parquet", 
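A usage sketch of the DataFrame-level `alias()` exercised above, assuming a local session: aliasing lets columns be addressed with a qualifier (useful for self-joins), while `stats::alias` remains reachable for non-Spark objects.

```R
df  <- as.DataFrame(cars)
df2 <- alias(df, "t")
head(select(df2, column("t.speed")))   # qualified column reference
head(select(df2, "t.speed"))           # string form works as well
```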
mode = "overwrite", path = parquetDataPath)) + saveAsTable(df, "parquetest", "parquet", mode = "overwrite", path = parquetDataPath) df5 <- sql("select * from parquetest") expect_is(df5, "SparkDataFrame") expect_equal(count(df5), 3) unlink(parquetDataPath) + unsetHiveContext() }) @@ -1153,6 +1376,8 @@ test_that("column operators", { c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c5 <- c2 ^ c3 ^ c4 + c6 <- c2 %<=>% c3 + c7 <- !c6 }) test_that("column functions", { @@ -1175,7 +1400,10 @@ test_that("column functions", { c16 <- is.nan(c) + isnan(c) + isNaN(c) c17 <- cov(c, c1) + cov("c", "c1") + covar_samp(c, c1) + covar_samp("c", "c1") c18 <- covar_pop(c, c1) + covar_pop("c", "c1") - c19 <- spark_partition_id() + c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) + c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") + c21 <- posexplode_outer(c) + explode_outer(c) + c22 <- not(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1191,6 +1419,11 @@ test_that("column functions", { expect_equal(collect(df2)[[3, 1]], FALSE) expect_equal(collect(df2)[[3, 2]], TRUE) + # Test that input_file_name() + actual_names <- sort(collect(distinct(select(df, input_file_name())))) + expect_equal(length(actual_names), 1) + expect_equal(basename(actual_names[1, 1]), basename(jsonPath)) + df3 <- select(df, between(df$name, c("Apache", "Spark"))) expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) @@ -1222,16 +1455,16 @@ test_that("column functions", { # Test struct() df <- createDataFrame(list(list(1L, 2L, 3L), list(4L, 5L, 6L)), schema = c("a", "b", "c")) - result <- collect(select(df, struct("a", "c"))) + result <- collect(select(df, alias(struct("a", "c"), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, c)" <- list(listToStruct(list(a = 1L, c = 3L)), - listToStruct(list(a = 4L, c = 6L))) + expected$"d" <- list(listToStruct(list(a = 1L, c = 3L)), + listToStruct(list(a = 4L, c = 6L))) expect_equal(result, expected) - result <- collect(select(df, struct(df$a, df$b))) + result <- collect(select(df, alias(struct(df$a, df$b), "d"))) expected <- data.frame(row.names = 1:2) - expected$"struct(a, b)" <- list(listToStruct(list(a = 1L, b = 2L)), - listToStruct(list(a = 4L, b = 5L))) + expected$"d" <- list(listToStruct(list(a = 1L, b = 2L)), + listToStruct(list(a = 4L, b = 5L))) expect_equal(result, expected) # Test encode(), decode() @@ -1244,9 +1477,9 @@ test_that("column functions", { # Test first(), last() df <- read.json(jsonPath) - expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age)))[[1]], NA_real_) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) - expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age")))[[1]], NA_real_) expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) expect_equal(collect(select(df, last(df$age)))[[1]], 19) expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) @@ -1257,6 +1490,71 @@ test_that("column functions", { df <- createDataFrame(data.frame(x = c(2.5, 3.5))) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + + # Test to_json(), from_json() + df <- sql("SELECT array(named_struct('name', 'Bob'), named_struct('name', 'Alice')) as people") + j <- collect(select(df, alias(to_json(df$people), "json"))) + 
expect_equal(j[order(j$json), ][1], "[{\"name\":\"Bob\"},{\"name\":\"Alice\"}]") + + df <- read.json(mapTypeJsonPath) + j <- collect(select(df, alias(to_json(df$info), "json"))) + expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") + df <- as.DataFrame(j) + schema <- structType(structField("age", "integer"), + structField("height", "double")) + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + + # passing option + df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) + schema2 <- structType(structField("date", "date")) + s <- collect(select(df, from_json(df$col, schema2))) + expect_equal(s[[1]][[1]], NA) + s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) + expect_is(s[[1]][[1]]$date, "Date") + expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") + + # check for unparseable + df <- as.DataFrame(list(list("a" = ""))) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) + + # check if array type in string is correctly supported. + jsonArr <- "[{\"name\":\"Bob\"}, {\"name\":\"Alice\"}]" + df <- as.DataFrame(list(list("people" = jsonArr))) + schema <- structType(structField("name", "string")) + arr <- collect(select(df, alias(from_json(df$people, schema, as.json.array = TRUE), "arrcol"))) + expect_equal(ncol(arr), 1) + expect_equal(nrow(arr), 1) + expect_is(arr[[1]][[1]], "list") + expect_equal(length(arr$arrcol[[1]]), 2) + expect_equal(arr$arrcol[[1]][[1]]$name, "Bob") + expect_equal(arr$arrcol[[1]][[2]]$name, "Alice") + + # Test create_array() and create_map() + df <- as.DataFrame(data.frame( + x = c(1.0, 2.0), y = c(-1.0, 3.0), z = c(-2.0, 5.0) + )) + + arrs <- collect(select(df, create_array(df$x, df$y, df$z))) + expect_equal(arrs[, 1], list(list(1, -1, -2), list(2, 3, 5))) + + maps <- collect(select( + df, create_map(lit("x"), df$x, lit("y"), df$y, lit("z"), df$z))) + + expect_equal( + maps[, 1], + lapply( + list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), + as.environment)) + + df <- as.DataFrame(data.frame(is_true = c(TRUE, FALSE, NA))) + expect_equal( + collect(select(df, alias(not(df$is_true), "is_false"))), + data.frame(is_false = c(FALSE, TRUE, NA)) + ) }) test_that("column binary mathfunctions", { @@ -1325,6 +1623,40 @@ test_that("string operators", { expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") + + l4 <- list(list(a = "a.b@c.d 1\\b")) + df4 <- createDataFrame(l4) + expect_equal( + collect(select(df4, split_string(df4$a, "\\s+")))[1, 1], + list(list("a.b@c.d", "1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.")))[1, 1], + list(list("a", "b@c", "d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "@")))[1, 1], + list(list("a.b", "c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], + list(list("a.b@c.d 1", "b")) + ) + + l5 <- list(list(a = "abc")) + df5 <- createDataFrame(l5) + expect_equal( + collect(select(df5, repeat_string(df5$a, 1L)))[1, 1], + "abc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, 3)))[1, 1], + "abcabcabc" + ) + expect_equal( + collect(select(df5, repeat_string(df5$a, -1)))[1, 1], + "" + ) 
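A round-trip sketch of the JSON column functions tested above, assuming a local session: `to_json()` serialises a struct column to a JSON string, and `from_json()` parses it back against an explicit schema.

```R
df <- sql("SELECT named_struct('name', 'Bob') AS person")
j  <- collect(select(df, alias(to_json(df$person), "json")))   # {"name":"Bob"}
df2    <- as.DataFrame(j)
schema <- structType(structField("name", "string"))
head(select(df2, alias(from_json(df2$json, schema), "person")))
```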
}) test_that("date functions on a DataFrame", { @@ -1510,6 +1842,28 @@ test_that("group by, agg functions", { expect_true(abs(sd(1:2) - 0.7071068) < 1e-6) expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6) + # Test collect_list and collect_set + gd3_collections_local <- collect( + agg(gd3, collect_set(df8$age), collect_list(df8$age)) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]), + c(30) + ) + + expect_equal( + unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]), + c(30, 30) + ) + + expect_equal( + sort(unlist( + gd3_collections_local[gd3_collections_local$name == "Justin", 3] + )), + c(1, 19) + ) + unlink(jsonPath2) unlink(jsonPath3) }) @@ -1539,6 +1893,160 @@ test_that("pivot GroupedData column", { expect_error(collect(sum(pivot(groupBy(df, "year"), "course", list("R", "R")), "earnings"))) }) +test_that("test multi-dimensional aggregations with cube and rollup", { + df <- createDataFrame(data.frame( + id = 1:6, + year = c(2016, 2016, 2016, 2017, 2017, 2017), + salary = c(10000, 15000, 20000, 22000, 32000, 21000), + department = c("management", "rnd", "sales", "management", "rnd", "sales") + )) + + actual_cube <- collect( + orderBy( + agg( + cube(df, "year", "department"), + expr("sum(salary) AS total_salary"), + expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") + ), + "year", "department" + ) + ) + + expected_cube <- data.frame( + year = c(rep(NA, 4), rep(2016, 4), rep(2017, 4)), + department = rep(c(NA, "management", "rnd", "sales"), times = 3), + total_salary = c( + 120000, # Total + 10000 + 22000, 15000 + 32000, 20000 + 21000, # Department only + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + # Mean by department + mean(c(10000, 22000)), mean(c(15000, 32000)), mean(c(20000, 21000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + grouping_year = c( + 1, # global + 1, 1, 1, # by department + 0, # 2016 + 0, 0, 0, # 2016 by department + 0, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_department = c( + 1, # global + 0, 0, 0, # by department + 1, # 2016 + 0, 0, 0, # 2016 by department + 1, # 2017 + 0, 0, 0 # 2017 by department + ), + grouping_id = c( + 3, # 11 + 2, 2, 2, # 10 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_cube, expected_cube) + + # cube should accept column objects + expect_equal( + count(sum(cube(df, df$year, df$department), "salary")), + 12 + ) + + # cube without columns should result in a single aggregate + expect_equal( + collect(agg(cube(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) + + actual_rollup <- collect( + orderBy( + agg( + rollup(df, "year", "department"), + expr("sum(salary) AS total_salary"), expr("avg(salary) AS average_salary"), + alias(grouping_bit(df$year), "grouping_year"), + alias(grouping_bit(df$department), "grouping_department"), + alias(grouping_id(df$year, df$department), "grouping_id") + ), + "year", "department" + ) + ) + + expected_rollup <- data.frame( + year = c(NA, rep(2016, 4), rep(2017, 4)), + 
department = c(NA, rep(c(NA, "management", "rnd", "sales"), times = 2)), + total_salary = c( + 120000, # Total + 20000 + 15000 + 10000, # 2016 + 10000, 15000, 20000, # 2016 each department + 21000 + 32000 + 22000, # 2017 + 22000, 32000, 21000 # 2017 each department + ), + average_salary = c( + # Total + mean(c(20000, 15000, 10000, 21000, 32000, 22000)), + mean(c(10000, 15000, 20000)), # 2016 + 10000, 15000, 20000, # 2016 each department + mean(c(21000, 32000, 22000)), # 2017 + 22000, 32000, 21000 # 2017 each department + ), + grouping_year = c( + 1, # global + 0, # 2016 + 0, 0, 0, # 2016 each department + 0, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_department = c( + 1, # global + 1, # 2016 + 0, 0, 0, # 2016 each department + 1, # 2017 + 0, 0, 0 # 2017 each department + ), + grouping_id = c( + 3, # 11 + 1, # 01 + 0, 0, 0, # 00 + 1, # 01 + 0, 0, 0 # 00 + ), + stringsAsFactors = FALSE + ) + + expect_equal(actual_rollup, expected_rollup) + + # cube should accept column objects + expect_equal( + count(sum(rollup(df, df$year, df$department), "salary")), + 9 + ) + + # rollup without columns should result in a single aggregate + expect_equal( + collect(agg(rollup(df), expr("sum(salary) as total_salary"))), + data.frame(total_salary = 120000) + ) +}) + test_that("arrange() and orderBy() on a DataFrame", { df <- read.json(jsonPath) sorted <- arrange(df, df$age) @@ -1584,6 +2092,16 @@ test_that("filter() on a DataFrame", { filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + # test suites for %<=>% + dfNa <- read.json(jsonPathNa) + expect_equal(count(filter(dfNa, dfNa$age %<=>% 60)), 1) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% 60))), 5 - 1) + expect_equal(count(filter(dfNa, dfNa$age %<=>% NULL)), 3) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% NULL))), 5 - 3) + # match NA from two columns + expect_equal(count(filter(dfNa, dfNa$age %<=>% dfNa$height)), 2) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% dfNa$height))), 5 - 2) + # Test stats::filter is working #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) @@ -1686,14 +2204,32 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) + + execution_plan_broadcast <- capture.output( + explain(join(df1, broadcast(df2), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_broadcast))) }) -test_that("toJSON() returns an RDD of the correct values", { - df <- read.json(jsonPath) - testRDD <- toJSON(df) - expect_is(testRDD, "RDD") - expect_equal(getSerializedMode(testRDD), "string") - expect_equal(collectRDD(testRDD)[[1]], mockLines[1]) +test_that("toJSON() on DataFrame", { + df <- as.DataFrame(cars) + df_json <- toJSON(df) + expect_is(df_json, "SparkDataFrame") + expect_equal(colnames(df_json), c("value")) + expect_equal(head(df_json, 1), + data.frame(value = "{\"speed\":4.0,\"dist\":2.0}", stringsAsFactors = FALSE)) }) test_that("showDF()", { @@ -1742,6 +2278,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(count(unioned2), 12) 
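A compact usage sketch of the multi-dimensional aggregations this test covers, assuming a local session: `cube()` aggregates over every combination of the grouping columns, `rollup()` over hierarchical prefixes, and `grouping_bit()` marks rows in which a column was rolled up.

```R
df <- createDataFrame(data.frame(year = c(2016, 2017), salary = c(10, 20)))
collect(agg(cube(df, "year"),
            expr("sum(salary) AS total"),
            alias(grouping_bit(df$year), "is_grand_total")))
# per-year rows carry is_grand_total = 0; the (NA, 30) row carries 1
```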
expect_equal(first(unioned2)$name, "Michael") + df3 <- df2 + names(df3)[1] <- "newName" + expect_error(rbind(df, df3), + "Names of input data frames are different.") + expect_error(rbind(df, df2, df3), + "Names of input data frames are different.") + excepted <- arrange(except(df, df2), desc(df$age)) expect_is(unioned, "SparkDataFrame") expect_equal(count(excepted), 2) @@ -1776,6 +2319,13 @@ test_that("withColumn() and withColumnRenamed()", { expect_equal(length(columns(newDF)), 2) expect_equal(first(filter(newDF, df$name != "Michael"))$age, 32) + newDF <- withColumn(df, "age", 18) + expect_equal(length(columns(newDF)), 2) + expect_equal(first(newDF)$age, 18) + + expect_error(withColumn(df, "age", list("a")), + "Literal value must be atomic in length of 1") + newDF2 <- withColumnRenamed(df, "age", "newerAge") expect_equal(length(columns(newDF2)), 2) expect_equal(columns(newDF2)[1], "newerAge") @@ -1830,6 +2380,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -1851,6 +2403,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -1897,6 +2451,8 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -1914,6 +2470,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -1935,6 +2493,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2152,14 +2712,24 @@ test_that("sampleBy() on a DataFrame", { }) test_that("approxQuantile() on a DataFrame", { - l <- lapply(c(0:99), function(i) { i }) - df <- createDataFrame(l, "key") - quantiles <- approxQuantile(df, "key", c(0.5, 0.8), 0.0) - expect_equal(quantiles[[1]], 50) - expect_equal(quantiles[[2]], 80) + l <- lapply(c(0:99), function(i) { list(i, 99 - i) }) + df <- createDataFrame(l, list("a", "b")) + quantiles <- approxQuantile(df, "a", c(0.5, 0.8), 0.0) + expect_equal(quantiles, list(50, 80)) + quantiles2 <- approxQuantile(df, c("a", "b"), c(0.5, 0.8), 0.0) + expect_equal(quantiles2[[1]], list(50, 80)) + expect_equal(quantiles2[[2]], list(50, 80)) + + dfWithNA <- createDataFrame(data.frame(a = c(NA, 30, 19, 11, 28, 15), + b = c(-30, -19, NA, -11, -28, -15))) + quantiles3 <- approxQuantile(dfWithNA, c("a", "b"), c(0.5), 0.0) + expect_equal(quantiles3[[1]], list(28)) + expect_equal(quantiles3[[2]], list(-15)) }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2168,6 +2738,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), 
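A sketch of the widened `withColumn()` contract checked above: the value may be a Column or a length-one atomic literal, which is applied to every row; non-atomic or longer values are rejected.

```R
df  <- as.DataFrame(cars)
df2 <- withColumn(df, "flag", 18)          # literal value, length 1
head(select(df2, "flag"), 3)               # 18 18 18
# withColumn(df, "flag", list("a"))        # errors: literal must be atomic, length 1
```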
collect(irisDF2)) @@ -2421,15 +2993,18 @@ test_that("repartition by columns on DataFrame", { ("Please, specify the number of partitions and/or a column\\(s\\)", retError), TRUE) # repartition by column and number of partitions - actual <- repartition(df, 3L, col = df$"a") + actual <- repartition(df, 3, col = df$"a") - # since we cannot access the number of partitions from dataframe, checking - # that at least the dimensions are identical + # Checking that at least the dimensions are identical expect_identical(dim(df), dim(actual)) + expect_equal(getNumPartitions(actual), 3L) # repartition by number of partitions actual <- repartition(df, 13L) expect_identical(dim(df), dim(actual)) + expect_equal(getNumPartitions(actual), 13L) + + expect_equal(getNumPartitions(coalesce(actual, 1L)), 1L) # a test case with a column and dapply schema <- structType(structField("a", "integer"), structField("avg", "double")) @@ -2445,6 +3020,25 @@ test_that("repartition by columns on DataFrame", { expect_equal(nrow(df1), 2) }) +test_that("coalesce, repartition, numPartitions", { + df <- as.DataFrame(cars, numPartitions = 5) + expect_equal(getNumPartitions(df), 5) + expect_equal(getNumPartitions(coalesce(df, 3)), 3) + expect_equal(getNumPartitions(coalesce(df, 6)), 5) + + df1 <- coalesce(df, 3) + expect_equal(getNumPartitions(df1), 3) + expect_equal(getNumPartitions(coalesce(df1, 6)), 5) + expect_equal(getNumPartitions(coalesce(df1, 4)), 4) + expect_equal(getNumPartitions(coalesce(df1, 2)), 2) + + df2 <- repartition(df1, 10) + expect_equal(getNumPartitions(df2), 10) + expect_equal(getNumPartitions(coalesce(df2, 13)), 10) + expect_equal(getNumPartitions(coalesce(df2, 7)), 7) + expect_equal(getNumPartitions(coalesce(df2, 3)), 3) +}) + test_that("gapply() and gapplyCollect() on a DataFrame", { df <- createDataFrame ( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), @@ -2563,6 +3157,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -2589,7 +3185,7 @@ test_that("createDataFrame sqlContext parameter backward compatibility", { # more tests for SPARK-16538 createOrReplaceTempView(df, "table") - SparkR::tables() + SparkR::listTables() SparkR::sql("SELECT 1") suppressWarnings(SparkR::sql(sqlContext, "SELECT * FROM table")) suppressWarnings(SparkR::dropTempTable(sqlContext, "table")) @@ -2612,7 +3208,7 @@ test_that("randomSplit", { expect_true(all(sapply(abs(counts / num - weights / sum(weights)), function(e) { e < 0.05 }))) }) -test_that("Setting and getting config on SparkSession", { +test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiWebUrl()", { # first, set it to a random but known value conf <- callJMethod(sparkSession, "conf") property <- paste0("spark.testing.", as.character(runif(1))) @@ -2636,9 +3232,14 @@ test_that("Setting and getting config on SparkSession", { expect_equal(appNameValue, "sparkSession test") expect_equal(testValue, value) expect_error(sparkR.conf("completely.dummy"), "Config 'completely.dummy' is not set") + + url <- sparkR.uiWebUrl() + expect_equal(substr(url, 1, 7), "http://") }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -2654,12 +3255,14 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java 
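A usage sketch of the extended `approxQuantile()` behaviour asserted above, assuming a local session: several columns can be queried at once and NA/NULL values are ignored, with one list of quantiles returned per column.

```R
df <- createDataFrame(data.frame(a = c(NA, as.numeric(1:9)),
                                 b = c(as.numeric(9:1), NA)))
approxQuantile(df, c("a", "b"), probabilities = 0.5, relativeError = 0.0)
# list(list(5), list(5)): the NA in each column is ignored
```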
without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls # DataFrameWriter.save() without path. expect_error(write.df(df, source = "csv"), - "Error in save : illegal argument - 'path' is not specified") + "Error in save : illegal argument - Expected exactly one path to be specified") expect_error(write.json(df, jsonPath), "Error in json : analysis error - path file:.*already exists") expect_error(write.text(df, jsonPath), @@ -2667,24 +3270,26 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume expect_error(write.orc(df, jsonPath), "Error in orc : analysis error - path file:.*already exists") expect_error(write.parquet(df, jsonPath), - "Error in parquet : analysis error - path file:.*already exists") + "Error in parquet : analysis error - path file:.*already exists") # Arguments checking in R side. expect_error(write.df(df, "data.tmp", source = c(1, 2)), paste("source should be character, NULL or omitted. It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) expect_error(write.df(df, path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(write.df(df, mode = TRUE), - "mode should be charactor or omitted. It is 'error' by default.") + "mode should be character or omitted. It is 'error' by default.") }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") @@ -2695,7 +3300,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # Arguments checking in R side. expect_error(read.df(path = c(3)), - "path should be charactor, NULL or omitted.") + "path should be character, NULL or omitted.") expect_error(read.df(jsonPath, source = c(1, 2)), paste("source should be character, NULL or omitted. 
It is the datasource specified", "in 'spark.sql.sources.default' configuration by default.")) @@ -2704,9 +3309,135 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { + ldf <- data.frame(col1 = c(0, 1, 2), + col2 = c(as.POSIXct("2017-01-01 00:00:01"), + NA, + as.POSIXct("2017-01-01 12:00:01")), + col3 = c(as.POSIXlt("2016-01-01 00:59:59"), + NA, + as.POSIXlt("2016-01-01 12:01:01"))) + sdf1 <- createDataFrame(ldf) + ldf1 <- collect(sdf1) + expect_equal(dtypes(sdf1), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf1$col1), "numeric") + expect_equal(class(ldf1$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf1$col3), c("POSIXct", "POSIXt")) + + # Columns with NAs at the top + sdf2 <- filter(sdf1, "col1 > 1") + ldf2 <- collect(sdf2) + expect_equal(dtypes(sdf2), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf2$col1), "numeric") + expect_equal(class(ldf2$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf2$col3), c("POSIXct", "POSIXt")) + + # Columns with only NAs, the type will also be cast to PRIMITIVE_TYPE + sdf3 <- filter(sdf1, "col1 == 0") + ldf3 <- collect(sdf3) + expect_equal(dtypes(sdf3), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf3$col1), "numeric") + expect_equal(class(ldf3$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) +}) + +test_that("catalog APIs, currentDatabase, setCurrentDatabase, listDatabases", { + expect_equal(currentDatabase(), "default") + expect_error(setCurrentDatabase("default"), NA) + expect_error(setCurrentDatabase("foo"), + "Error in setCurrentDatabase : analysis error - Database 'foo' does not exist") + dbs <- collect(listDatabases()) + expect_equal(names(dbs), c("name", "description", "locationUri")) + expect_equal(dbs[[1]], "default") +}) + +test_that("catalog APIs, listTables, listColumns, listFunctions", { + tb <- listTables() + count <- count(tables()) + expect_equal(nrow(tb), count) + expect_equal(colnames(tb), c("name", "database", "description", "tableType", "isTemporary")) + + createOrReplaceTempView(as.DataFrame(cars), "cars") + + tb <- listTables() + expect_equal(nrow(tb), count + 1) + tbs <- collect(tb) + expect_true(nrow(tbs[tbs$name == "cars", ]) > 0) + expect_error(listTables("bar"), + "Error in listTables : no such database - Database 'bar' not found") + + c <- listColumns("cars") + expect_equal(nrow(c), 2) + expect_equal(colnames(c), + c("name", "description", "dataType", "nullable", "isPartition", "isBucket")) + expect_equal(collect(c)[[1]][[1]], "speed") + expect_error(listColumns("foo", "default"), + "Error in listColumns : analysis error - Table 'foo' does not exist in database 'default'") + + f <- listFunctions() + expect_true(nrow(f) >= 200) # 250 + expect_equal(colnames(f), + c("name", "database", "description", "className", "isTemporary")) + expect_equal(take(orderBy(f, "className"), 1)$className, + "org.apache.spark.sql.catalyst.expressions.Abs") + expect_error(listFunctions("foo_db"), + "Error in listFunctions : analysis error - Database 'foo_db' does not exist") + + # recoverPartitions does not work with tempory view + expect_error(recoverPartitions("cars"), + "no such table - Table or view 'cars' not found in database 'default'") + 
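A short tour of the catalog API surface these tests exercise, assuming a local session using the default database:

```R
currentDatabase()                          # "default"
createOrReplaceTempView(as.DataFrame(cars), "cars")
collect(listTables())$name                 # includes "cars"
collect(listColumns("cars"))$name          # "speed" "dist"
dropTempView("cars")
```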
expect_error(refreshTable("cars"), NA) + expect_error(refreshByPath("/"), NA) + + dropTempView("cars") +}) + +compare_list <- function(list1, list2) { + # get testthat to show the diff by first making the 2 lists equal in length + expect_equal(length(list1), length(list2)) + l <- max(length(list1), length(list2)) + length(list1) <- l + length(list2) <- l + expect_equal(sort(list1, na.last = TRUE), sort(list2, na.last = TRUE)) +} + +# This should always be the **very last test** in this test file. +test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() # skip because when run from R CMD check SPARK_HOME is not the current directory + + # Check that it is not creating any extra file. + # Does not check the tempdir which would be cleaned up after. + filesAfter <- list.files(path = sparkRDir, all.files = TRUE) + + expect_true(length(sparkRFilesBefore) > 0) + # first, ensure derby.log is not there + expect_false("derby.log" %in% filesAfter) + # second, ensure only spark-warehouse is created when calling SparkSession, enableHiveSupport = F + # note: currently all other test files have enableHiveSupport = F, so we capture the list of files + # before creating a SparkSession with enableHiveSupport = T at the top of this test file + # (filesBefore). The test here is to compare that (filesBefore) against the list of files before + # any test is run in run-all.R (sparkRFilesBefore). + # sparkRWhitelistSQLDirs is also defined in run-all.R, and should contain only 2 whitelisted dirs, + # here allow the first value, spark-warehouse, in the diff, everything else should be exactly the + # same as before any test is run. + compare_list(sparkRFilesBefore, setdiff(filesBefore, sparkRWhitelistSQLDirs[[1]])) + # third, ensure only spark-warehouse and metastore_db are created when enableHiveSupport = T + # note: as the note above, after running all tests in this file while enableHiveSupport = T, we + # check the list of files again. This time we allow both whitelisted dirs to be in the diff. + compare_list(sparkRFilesBefore, setdiff(filesAfter, sparkRWhitelistSQLDirs)) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) unlink(jsonPathNa) +unlink(complexTypeJsonPath) +unlink(mapTypeJsonPath) sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R new file mode 100644 index 0000000000000..b20b4312fbaae --- /dev/null +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("Structured Streaming") + +# Tests for Structured Streaming functions in SparkR + +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) + +jsonSubDir <- file.path("sparkr-test", "json", "") +if (.Platform$OS.type == "windows") { + # file.path removes the empty separator on Windows, adds it back + jsonSubDir <- paste0(jsonSubDir, .Platform$file.sep) +} +jsonDir <- file.path(tempdir(), jsonSubDir) +dir.create(jsonDir, recursive = TRUE) + +mockLines <- c("{\"name\":\"Michael\"}", + "{\"name\":\"Andy\", \"age\":30}", + "{\"name\":\"Justin\", \"age\":19}") +jsonPath <- tempfile(pattern = jsonSubDir, fileext = ".tmp") +writeLines(mockLines, jsonPath) + +mockLinesNa <- c("{\"name\":\"Bob\",\"age\":16,\"height\":176.5}", + "{\"name\":\"Alice\",\"age\":null,\"height\":164.3}", + "{\"name\":\"David\",\"age\":60,\"height\":null}") +jsonPathNa <- tempfile(pattern = jsonSubDir, fileext = ".tmp") + +schema <- structType(structField("name", "string"), + structField("age", "integer"), + structField("count", "double")) + +test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) + + writeLines(mockLinesNa, jsonPathNa) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) +}) + +test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") + + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") + expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) + expect_true(any(grepl("\"isTriggerActive\" : ", capture.output(status(q))))) + + expect_equal(queryName(q), "people2") + expect_true(isActive(q)) + + stopQuery(q) +}) + +test_that("Stream other format", { + skip_on_cran() + + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + df <- read.df(jsonPath, "json", schema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = schema) + expect_true(isStreaming(df)) + counts <- count(group_by(df, "name")) + q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") + + expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") + expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) + + expect_equal(queryName(q), "people3") + expect_true(any(grepl("\"description\" : \"FileStreamSource[[:print:]]+parquet", + capture.output(lastProgress(q))))) + expect_true(isActive(q)) + + stopQuery(q) + expect_true(awaitTermination(q, 1)) + expect_false(isActive(q)) + + unlink(parquetPath) +}) + +test_that("Non-streaming DataFrame", { + skip_on_cran() + + c <- 
as.DataFrame(cars) + expect_false(isStreaming(c)) + + expect_error(write.stream(c, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(writeStream : analysis error - 'writeStream' can be called only on ", + "streaming Dataset/DataFrame).*")) +}) + +test_that("Unsupported operation", { + skip_on_cran() + + # memory sink without aggregation + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) + expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), + paste0(".*(start : analysis error - Complete output mode not supported when there ", + "are no streaming aggregations on streaming DataFrames/Datasets).*")) +}) + +test_that("Terminated by error", { + skip_on_cran() + + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) + counts <- count(group_by(df, "name")) + # This would not fail before returning with a StreamingQuery, + # but could dump error log at just about the same time + expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), + NA) + + expect_error(awaitTermination(q, 5 * 1000), + paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", + " 'maxFilesPerTrigger', must be a positive integer).*")) + + expect_true(any(grepl("\"message\" : \"Terminated with exception: Invalid value", + capture.output(status(q))))) + expect_true(any(grepl("Streaming query has no progress", capture.output(lastProgress(q))))) + expect_equal(queryName(q), "people4") + expect_false(isActive(q)) + + stopQuery(q) +}) + +unlink(jsonPath) +unlink(jsonPathNa) + +sparkR.session.stop() diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index aaa532856c3d9..c00723ba31f4c 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -30,10 +30,12 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ", "raising me. But they're both dead now. I didn't kill them. 
Honest.") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e9390..e8a961cb3e870 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -18,12 +18,14 @@ context("the textFile() function") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) 
test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 607c407f04f97..02691f0f64314 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -18,11 +18,12 @@ context("functions in utils.R") # JavaSparkContext handle -sparkSession <- sparkR.session(enableHiveSupport = FALSE) +sparkSession <- sparkR.session(master = sparkRTestMaster, enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. @@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -134,7 +136,7 @@ test_that("cleanClosure on R functions", { # Test for broadcast variables. a <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) - aBroadcast <- broadcast(sc, a) + aBroadcast <- broadcastRDD(sc, a) normMultiply <- function(x) { norm(aBroadcast$value) * x } newnormMultiply <- SparkR:::cleanClosure(normMultiply) env <- environment(newnormMultiply) @@ -167,16 +169,20 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { - method <- "getSQLDataType" + skip_on_cran() + + method <- "createStructField" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, - "unknown"), + "col", "unknown", TRUE), error = function(e) { captureJVMException(e, method) }), - "Error in getSQLDataType : illegal argument - Invalid type unknown") + "parse error - .*DataType unknown.*not supported.") }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) @@ -228,4 +234,12 @@ test_that("varargsToStrEnv", { expect_warning(varargsToStrEnv(1, 2, 3, 4), "Unnamed arguments ignored: 1, 2, 3, 4.") }) +test_that("basenameSansExtFromUrl", { + x <- paste0("http://people.apache.org/~pwendell/spark-nightly/spark-branch-2.1-bin/spark-2.1.1-", + "SNAPSHOT-2016_12_09_11_08-eb2d9bf-bin/spark-2.1.1-SNAPSHOT-bin-hadoop2.7.tgz") + expect_equal(basenameSansExtFromUrl(x), "spark-2.1.1-SNAPSHOT-bin-hadoop2.7") + z <- "http://people.apache.org/~pwendell/spark-releases/spark-2.1.0--hive.tar.gz" + expect_equal(basenameSansExtFromUrl(z), "spark-2.1.0--hive") +}) + sparkR.session.stop() diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R index 1d04656ac2594..9c6cba535d118 100644 --- a/R/pkg/tests/run-all.R +++ b/R/pkg/tests/run-all.R @@ -21,4 +21,19 @@ library(SparkR) # Turn all warnings into errors options("warn" = 2) +# Setup global test environment +# Install Spark first to set SPARK_HOME +install.spark() + +sparkRDir <- file.path(Sys.getenv("SPARK_HOME"), "R") +sparkRFilesBefore <- list.files(path = sparkRDir, all.files = TRUE) +sparkRWhitelistSQLDirs <- c("spark-warehouse", "metastore_db") +invisible(lapply(sparkRWhitelistSQLDirs, + function(x) { unlink(file.path(sparkRDir, x), recursive = TRUE, force = TRUE)})) + +sparkRTestMaster <- "local[1]" +if 
(identical(Sys.getenv("NOT_CRAN"), "true")) { + sparkRTestMaster <- "" +} + test_package("SparkR") diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index 80e876027bddb..13a399165c8b4 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -1,14 +1,32 @@ --- title: "SparkR - Practical Guide" output: - html_document: - theme: united + rmarkdown::html_vignette: toc: true toc_depth: 4 - toc_float: true - highlight: textmate +vignette: > + %\VignetteIndexEntry{SparkR - Practical Guide} + %\VignetteEngine{knitr::rmarkdown} + \usepackage[utf8]{inputenc} --- + + ## Overview SparkR is an R package that provides a light-weight frontend to use Apache Spark from R. With Spark `r packageVersion("SparkR")`, SparkR provides a distributed data frame implementation that supports data processing operations like selection, filtering, aggregation etc. and distributed machine learning using [MLlib](http://spark.apache.org/mllib/). @@ -26,7 +44,11 @@ library(SparkR) We use default settings in which it runs in local mode. It auto downloads Spark package in the background if no previous installation is found. For more details about setup, see [Spark Session](#SetupSparkSession). -```{r, message=FALSE, results="hide"} +```{r, include=FALSE} +install.spark() +sparkR.session(master = "local[1]") +``` +```{r, eval=FALSE} sparkR.session() ``` @@ -44,7 +66,7 @@ We can view the first few rows of the `SparkDataFrame` by `head` or `showDF` fun head(carsDF) ``` -Common data processing operations such as `filter`, `select` are supported on the `SparkDataFrame`. +Common data processing operations such as `filter` and `select` are supported on the `SparkDataFrame`. ```{r} carsSubDF <- select(carsDF, "model", "mpg", "hp") carsSubDF <- filter(carsSubDF, carsSubDF$hp >= 200) @@ -93,13 +115,13 @@ sparkR.session.stop() Different from many other R packages, to use SparkR, you need an additional installation of Apache Spark. The Spark installation will be used to run a backend process that will compile and execute SparkR programs. -If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). Alternatively, we provide an easy-to-use function `install.spark` to complete this process. You don't have to call it explicitly. We will check the installation when `sparkR.session` is called and `install.spark` function will be triggered automatically if no installation is found. +After installing the SparkR package, you can call `sparkR.session` as explained in the previous section to start and it will check for the Spark installation. If you are working with SparkR from an interactive shell (eg. R, RStudio) then Spark is downloaded and cached automatically if it is not found. Alternatively, we provide an easy-to-use function `install.spark` for running this manually. If you don't have Spark installed on the computer, you may download it from [Apache Spark Website](http://spark.apache.org/downloads.html). ```{r, eval=FALSE} install.spark() ``` -If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the Spark installation is. +If you already have Spark installed, you don't have to install again and can pass the `sparkHome` argument to `sparkR.session` to let SparkR know where the existing Spark installation is. 
```{r, eval=FALSE} sparkR.session(sparkHome = "/HOME/spark") @@ -161,7 +183,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -211,7 +233,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -287,6 +309,21 @@ numCyl <- summarize(groupBy(carsDF, carsDF$cyl), count = n(carsDF$cyl)) head(numCyl) ``` +Use `cube` or `rollup` to compute subtotals across multiple dimensions. + +```{r} +mean(cube(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates groupings for all possible combinations of the grouping columns, while + +```{r} +mean(rollup(carsDF, "cyl", "gear", "am"), "mpg") +``` + +generates the hierarchical groupings {(`cyl`, `gear`, `am`), (`cyl`, `gear`), (`cyl`), ()}. + + #### Operating on Columns SparkR also provides a number of functions that can be directly applied to columns for data processing and during aggregation. The example below shows the use of basic arithmetic functions. @@ -343,7 +380,7 @@ out <- dapply(carsSubDF, function(x) { x <- cbind(x, x$mpg * 1.61) }, schema) head(collect(out)) ``` -Like `dapply`, apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. 
+Like `dapply`, `dapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back. The output of the function should be a `data.frame`, but no schema is required in this case. Note that `dapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} out <- dapplyCollect( @@ -369,7 +406,7 @@ result <- gapply( head(arrange(result, "max_mpg", decreasing = TRUE)) ``` -Like gapply, `gapplyCollect` applies a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of UDF run on all the partition cannot be pulled to the driver and fit in driver memory. +Like `gapply`, `gapplyCollect` can apply a function to each partition of a `SparkDataFrame` and collect the result back to R `data.frame`. The output of the function should be a `data.frame` but no schema is required in this case. Note that `gapplyCollect` can fail if the output of the UDF on all partitions cannot be pulled into the driver's memory. ```{r} result <- gapplyCollect( @@ -422,20 +459,20 @@ options(ops) ### SQL Queries -A `SparkDataFrame` can also be registered as a temporary view in Spark SQL and that allows you to run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. +A `SparkDataFrame` can also be registered as a temporary view in Spark SQL so that one can run SQL queries over its data. The sql function enables applications to run SQL queries programmatically and returns the result as a `SparkDataFrame`. ```{r} people <- read.df(paste0(sparkR.conf("spark.home"), "/examples/src/main/resources/people.json"), "json") ``` -Register this SparkDataFrame as a temporary view. +Register this `SparkDataFrame` as a temporary view. ```{r} createOrReplaceTempView(people, "people") ``` -SQL statements can be run by using the sql method. +SQL statements can be run using the sql method. ```{r} teenagers <- sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") head(teenagers) @@ -446,25 +483,51 @@ head(teenagers) SparkR supports the following machine learning models and algorithms. -* Generalized Linear Model (GLM) +#### Classification -* Naive Bayes Model +* Linear Support Vector Machine (SVM) Classifier -* $k$-means Clustering +* Logistic Regression + +* Multilayer Perceptron (MLP) + +* Naive Bayes + +#### Regression * Accelerated Failure Time (AFT) Survival Model +* Generalized Linear Model (GLM) + +* Isotonic Regression + +#### Tree - Classification and Regression + +* Gradient-Boosted Trees (GBT) + +* Random Forest + +#### Clustering + +* Bisecting $k$-means + * Gaussian Mixture Model (GMM) +* $k$-means Clustering + * Latent Dirichlet Allocation (LDA) -* Multilayer Perceptron Model +#### Collaborative Filtering + +* Alternating Least Squares (ALS) + +#### Frequent Pattern Mining -* Collaborative Filtering with Alternating Least Squares (ALS) +* FP-growth -* Isotonic Regression Model +#### Statistics -More will be added in the future. +* Kolmogorov-Smirnov Test ### R Formula @@ -489,9 +552,137 @@ count(carsDF_test) head(carsDF_test) ``` - ### Models and Algorithms +#### Linear Support Vector Machine (SVM) Classifier + +[Linear Support Vector Machine (SVM)](https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM) classifier is an SVM classifier with linear kernels. 
+This is a binary classifier. We use a simple example to show how to use `spark.svmLinear` +for binary classification. + +```{r} +# load training data and create a DataFrame +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +# fit a Linear SVM classifier model +model <- spark.svmLinear(training, Survived ~ ., regParam = 0.01, maxIter = 10) +summary(model) +``` + +Predict values on training data +```{r} +prediction <- predict(model, training) +``` + +#### Logistic Regression + +[Logistic regression](https://en.wikipedia.org/wiki/Logistic_regression) is a widely-used model when the response is categorical. It can be seen as a special case of the [Generalized Linear Predictive Model](https://en.wikipedia.org/wiki/Generalized_linear_model). +We provide `spark.logit` on top of `spark.glm` to support logistic regression with advanced hyper-parameters. +It supports both binary and multiclass classification with elastic-net regularization and feature standardization, similar to `glmnet`. + +We use a simple example to demonstrate `spark.logit` usage. In general, there are three steps of using `spark.logit`: +1). Create a dataframe from a proper data source; 2). Fit a logistic regression model using `spark.logit` with a proper parameter setting; +and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. + +Binomial logistic regression +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.logit(training, Survived ~ ., regParam = 0.04741301) +summary(model) +``` + +Predict values on training data +```{r} +fitted <- predict(model, training) +``` + +Multinomial logistic regression against three classes +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +# Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. +model <- spark.logit(training, Class ~ ., regParam = 0.07815179) +summary(model) +``` + +#### Multilayer Perceptron + +Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: +$$ +y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). +$$ + +Nodes in intermediate layers use sigmoid (logistic) function: +$$ +f(z_i) = \frac{1}{1+e^{-z_i}}. +$$ + +Nodes in the output layer use softmax function: +$$ +f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. +$$ + +The number of nodes $N$ in the output layer corresponds to the number of classes. + +MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. + +`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. + +We use Titanic data set to show how to use `spark.mlp` in classification. 
+```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +# fit a Multilayer Perceptron Classification Model +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) +``` + +To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. +```{r, include=FALSE} +ops <- options() +options(max.print=5) +``` +```{r} +# check the summary of the fitted model +summary(model) +``` +```{r, include=FALSE} +options(ops) +``` +```{r} +# make predictions use the fitted model +predictions <- predict(model, training) +head(select(predictions, predictions$prediction)) +``` + +#### Naive Bayes + +Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. + +```{r} +titanic <- as.data.frame(Titanic) +titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) +naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) +summary(naiveBayesModel) +naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) +head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +``` + +#### Accelerated Failure Time Survival Model + +Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. + +Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + +```{r, warning=FALSE} +library(survival) +ovarianDF <- createDataFrame(ovarian) +aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) +summary(aftModel) +aftPredictions <- predict(aftModel, ovarianDF) +head(aftPredictions) +``` + #### Generalized Linear Model The main function is `spark.glm`. The following families and link functions are supported. The default is gaussian. @@ -502,6 +693,7 @@ gaussian | identity, log, inverse binomial | logit, probit, cloglog (complementary log-log) poisson | log, identity, sqrt gamma | inverse, identity, log +tweedie | power link function There are three ways to specify the `family` argument. @@ -509,7 +701,11 @@ There are three ways to specify the `family` argument. * Family function, e.g. `family = binomial`. -* Result returned by a family function, e.g. `family = poisson(link = log)` +* Result returned by a family function, e.g. `family = poisson(link = log)`. 
+ +* Note that there are two ways to specify the tweedie family: + a) Set `family = "tweedie"` and specify the `var.power` and `link.power` + b) When package `statmod` is loaded, the tweedie family is specified using the family definition therein, i.e., `tweedie()`. For more information regarding the families and their link functions, see the Wikipedia page [Generalized Linear Model](https://en.wikipedia.org/wiki/Generalized_linear_model). @@ -525,50 +721,107 @@ gaussianFitted <- predict(gaussianGLM, carsDF) head(select(gaussianFitted, "model", "prediction", "mpg", "wt", "hp")) ``` -#### Naive Bayes Model +The following is the same fit using the tweedie family: +```{r} +tweedieGLM1 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 0.0) +summary(tweedieGLM1) +``` +We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: +```{r} +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", + var.power = 1.2, link.power = 0.0) +summary(tweedieGLM2) +``` + +#### Isotonic Regression -Naive Bayes model assumes independence among the features. `spark.naiveBayes` fits a [Bernoulli naive Bayes model](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Bernoulli_naive_Bayes) against a SparkDataFrame. The data should be all categorical. These models are often used for document classification. +`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize +$$ +\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. +$$ + +There are a few more arguments that may be useful. + +* `weightCol`: a character string specifying the weight column. + +* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). + +* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. + +We use an artificial example to show the use. ```{r} -titanic <- as.data.frame(Titanic) -titanicDF <- createDataFrame(titanic[titanic$Freq > 0, -5]) -naiveBayesModel <- spark.naiveBayes(titanicDF, Survived ~ Class + Sex + Age) -summary(naiveBayesModel) -naiveBayesPrediction <- predict(naiveBayesModel, titanicDF) -head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction")) +y <- c(3.0, 6.0, 8.0, 5.0, 7.0) +x <- c(1.0, 2.0, 3.5, 3.0, 4.0) +w <- rep(1.0, 5) +data <- data.frame(y = y, x = x, w = w) +df <- createDataFrame(data) +isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") +isoregFitted <- predict(isoregModel, df) +head(select(isoregFitted, "x", "y", "prediction")) ``` -#### k-Means Clustering +In the prediction stage, based on the fitted monotone piecewise function, the rules are: -`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. +* If the prediction input exactly matches a training feature then associated prediction is returned. 
In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. + +* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. + +* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. + +For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. ```{r} -kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) -summary(kmeansModel) -kmeansPredictions <- predict(kmeansModel, carsDF) -head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) +head(predict(isoregModel, newDF)) ``` -#### AFT Survival Model -Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. +#### Gradient-Boosted Trees + +`spark.gbt` fits a [gradient-boosted tree](https://en.wikipedia.org/wiki/Gradient_boosting) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. + +We use the `longley` dataset to train a gradient-boosted tree and make predictions: -Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. ```{r, warning=FALSE} -library(survival) -ovarianDF <- createDataFrame(ovarian) -aftModel <- spark.survreg(ovarianDF, Surv(futime, fustat) ~ ecog_ps + rx) -summary(aftModel) -aftPredictions <- predict(aftModel, ovarianDF) -head(aftPredictions) +df <- createDataFrame(longley) +gbtModel <- spark.gbt(df, Employed ~ ., type = "regression", maxDepth = 2, maxIter = 2) +summary(gbtModel) +predictions <- predict(gbtModel, df) ``` -#### Gaussian Mixture Model +#### Random Forest + +`spark.randomForest` fits a [random forest](https://en.wikipedia.org/wiki/Random_forest) classification or regression model on a `SparkDataFrame`. +Users can call `summary` to get a summary of the fitted model, `predict` to make predictions, and `write.ml`/`read.ml` to save/load fitted models. 
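To make the `write.ml`/`read.ml` round trip mentioned above concrete, here is a rough, self-contained sketch; it simply repeats the `longley` fit shown in the example that follows and writes the model to an illustrative temporary path, so treat it as a sketch rather than part of the original vignette.

```{r, eval=FALSE}
# Sketch: fit a random forest (same call as the example below), then save and reload it.
df <- createDataFrame(longley)
rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2)
modelPath <- tempfile(pattern = "rf", fileext = ".tmp")  # illustrative temporary path
write.ml(rfModel, modelPath)
rfModel2 <- read.ml(modelPath)  # the reloaded model can be used like the original
summary(rfModel2)
unlink(modelPath)
```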
+ +In the following example, we use the `longley` dataset to train a random forest and make predictions: + +```{r, warning=FALSE} +df <- createDataFrame(longley) +rfModel <- spark.randomForest(df, Employed ~ ., type = "regression", maxDepth = 2, numTrees = 2) +summary(rfModel) +predictions <- predict(rfModel, df) +``` -(Coming in 2.1.0) +#### Bisecting k-Means + +`spark.bisectingKmeans` is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) using a divisive (or "top-down") approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. + +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) +summary(model) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) +``` + +#### Gaussian Mixture Model `spark.gaussianMixture` fits multivariate [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model) (GMM) against a `SparkDataFrame`. [Expectation-Maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm) (EM) is used to approximate the maximum likelihood estimator (MLE) of the model. -We use a simulated example to demostrate the usage. +We use a simulated example to demonstrate the usage. ```{r} X1 <- data.frame(V1 = rnorm(4), V2 = rnorm(4)) X2 <- data.frame(V1 = rnorm(6, 3), V2 = rnorm(6, 4)) @@ -580,10 +833,18 @@ gmmFitted <- predict(gmmModel, df) head(select(gmmFitted, "V1", "V2", "prediction")) ``` +#### k-Means Clustering -#### Latent Dirichlet Allocation +`spark.kmeans` fits a $k$-means clustering model against a `SparkDataFrame`. As an unsupervised learning method, we don't need a response variable. Hence, the left hand side of the R formula should be left blank. The clustering is based only on the variables on the right hand side. -(Coming in 2.1.0) +```{r} +kmeansModel <- spark.kmeans(carsDF, ~ mpg + hp + wt, k = 3) +summary(kmeansModel) +kmeansPredictions <- predict(kmeansModel, carsDF) +head(select(kmeansPredictions, "model", "mpg", "hp", "wt", "prediction"), n = 20L) +``` + +#### Latent Dirichlet Allocation `spark.lda` fits a [Latent Dirichlet Allocation](https://en.wikipedia.org/wiki/Latent_Dirichlet_allocation) model on a `SparkDataFrame`. It is often used in topic modeling in which topics are inferred from a collection of text documents. LDA can be thought of as a clustering algorithm as follows: @@ -591,30 +852,14 @@ head(select(gmmFitted, "V1", "V2", "prediction")) * Topics and documents both exist in a feature space, where feature vectors are vectors of word counts (bag of words). -* Rather than estimating a clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. +* Rather than clustering using a traditional distance, LDA uses a function based on a statistical model of how text documents are generated. -To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two type options for the column: +To use LDA, we need to specify a `features` column in `data` where each entry represents a document. There are two options for the column: * character string: This can be a string of the whole document. It will be parsed automatically. Additional stop words can be added in `customizedStopWords`. * libSVM: Each entry is a collection of words and will be processed directly. 
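As a hedged illustration of the character-string option above, the following minimal sketch fits an LDA model on a tiny made-up corpus; the column name `features` is the default expected by `spark.lda`, and the toy documents are invented purely for illustration.

```{r, eval=FALSE}
# Sketch: each entry of the character column is a whole document and is parsed automatically.
corpus <- createDataFrame(data.frame(features = c(
  "spark makes distributed data processing simple",
  "r users work with data frames",
  "spark and r work well together"),
  stringsAsFactors = FALSE))
ldaModel <- spark.lda(corpus, k = 2, maxIter = 10)
summary(ldaModel)
# posterior topic distribution for each document
head(spark.posterior(ldaModel, corpus))
```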
-There are several parameters LDA takes for fitting the model. - -* `k`: number of topics (default 10). - -* `maxIter`: maximum iterations (default 20). - -* `optimizer`: optimizer to train an LDA model, "online" (default) uses [online variational inference](https://www.cs.princeton.edu/~blei/papers/HoffmanBleiBach2010b.pdf). "em" uses [expectation-maximization](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm). - -* `subsamplingRate`: For `optimizer = "online"`. Fraction of the corpus to be sampled and used in each iteration of mini-batch gradient descent, in range (0, 1] (default 0.05). - -* `topicConcentration`: concentration parameter (commonly named beta or eta) for the prior placed on topic distributions over terms, default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective topicConcentration. Only 1-size numeric is accepted. - -* `docConcentration`: concentration parameter (commonly named alpha) for the prior placed on documents distributions over topics (theta), default -1 to set automatically on the Spark side. Use `summary` to retrieve the effective docConcentration. Only 1-size or k-size numeric is accepted. - -* `maxVocabSize`: maximum vocabulary size, default 1 << 18. - Two more functions are provided for the fitted model. * `spark.posterior` returns a `SparkDataFrame` containing a column of posterior probabilities vectors named "topicDistribution". @@ -653,53 +898,13 @@ perplexity <- spark.perplexity(model, corpusDF) perplexity ``` - -#### Multilayer Perceptron - -(Coming in 2.1.0) - -Multilayer perceptron classifier (MLPC) is a classifier based on the [feedforward artificial neural network](https://en.wikipedia.org/wiki/Feedforward_neural_network). MLPC consists of multiple layers of nodes. Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs by a linear combination of the inputs with the node’s weights $w$ and bias $b$ and applying an activation function. This can be written in matrix form for MLPC with $K+1$ layers as follows: -$$ -y(x)=f_K(\ldots f_2(w_2^T f_1(w_1^T x + b_1) + b_2) \ldots + b_K). -$$ - -Nodes in intermediate layers use sigmoid (logistic) function: -$$ -f(z_i) = \frac{1}{1+e^{-z_i}}. -$$ - -Nodes in the output layer use softmax function: -$$ -f(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}. -$$ - -The number of nodes $N$ in the output layer corresponds to the number of classes. - -MLPC employs backpropagation for learning the model. We use the logistic loss function for optimization and L-BFGS as an optimization routine. - -`spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. The `"features"` column should be in libSVM-format. According to the description above, there are several additional parameters that can be set: - -* `layers`: integer vector containing the number of nodes for each layer. - -* `solver`: solver parameter, supported options: `"gd"` (minibatch gradient descent) or `"l-bfgs"`. - -* `maxIter`: maximum iteration number. - -* `tol`: convergence tolerance of iterations. - -* `stepSize`: step size for `"gd"`. - -* `seed`: seed parameter for weights initialization. 
- -#### Collaborative Filtering - -(Coming in 2.1.0) +#### Alternating Least Squares `spark.als` learns latent factors in [collaborative filtering](https://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) via [alternating least squares](http://dl.acm.org/citation.cfm?id=1608614). -There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. +There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, and `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -707,7 +912,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -717,64 +922,71 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` -#### Isotonic Regression Model - -(Coming in 2.1.0) +#### FP-growth -`spark.isoreg` fits an [Isotonic Regression](https://en.wikipedia.org/wiki/Isotonic_regression) model against a `SparkDataFrame`. It solves a weighted univariate a regression problem under a complete order constraint. Specifically, given a set of real observed responses $y_1, \ldots, y_n$, corresponding real features $x_1, \ldots, x_n$, and optionally positive weights $w_1, \ldots, w_n$, we want to find a monotone (piecewise linear) function $f$ to minimize -$$ -\ell(f) = \sum_{i=1}^n w_i (y_i - f(x_i))^2. -$$ +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. -There are a few more arguments that may be useful. +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") -* `weightCol`: a character string specifying the weight column. +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` -* `isotonic`: logical value indicating whether the output sequence should be isotonic/increasing (`TRUE`) or antitonic/decreasing (`FALSE`). +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. -* `featureIndex`: the index of the feature on the right hand side of the formula if it is a vector column (default: 0), no effect otherwise. +```{r} +head(spark.freqItemsets(fpm)) +``` -We use an artificial example to show the use. +`spark.associationRules` returns a `SparkDataFrame` with the association rules. ```{r} -y <- c(3.0, 6.0, 8.0, 5.0, 7.0) -x <- c(1.0, 2.0, 3.5, 3.0, 4.0) -w <- rep(1.0, 5) -data <- data.frame(y = y, x = x, w = w) -df <- createDataFrame(data) -isoregModel <- spark.isoreg(df, y ~ x, weightCol = "w") -isoregFitted <- predict(isoregModel, df) -head(select(isoregFitted, "x", "y", "prediction")) +head(spark.associationRules(fpm)) ``` -In the prediction stage, based on the fitted monotone piecewise function, the rules are: +We can make predictions based on the `antecedent`. -* If the prediction input exactly matches a training feature then associated prediction is returned. 
In case there are multiple predictions with the same feature then one of them is returned. Which one is undefined. +```{r} +head(predict(fpm, df)) +``` -* If the prediction input is lower or higher than all training features then prediction with lowest or highest feature is returned respectively. In case there are multiple predictions with the same feature then the lowest or highest is returned respectively. +#### Kolmogorov-Smirnov Test -* If the prediction input falls between two training features then prediction is treated as piecewise linear function and interpolated value is calculated from the predictions of the two closest features. In case there are multiple values with the same feature then the same rules as in previous point are used. +`spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). +Given a `SparkDataFrame`, the test compares continuous data in a given column `testCol` with the theoretical distribution +specified by parameter `nullHypothesis`. +Users can call `summary` to get a summary of the test results. -For example, when the input is $3.2$, the two closest feature values are $3.0$ and $3.5$, then predicted value would be a linear interpolation between the predicted values at $3.0$ and $3.5$. +In the following example, we test whether the `longley` dataset's `Armed_Forces` column +follows a normal distribution. We set the parameters of the normal distribution using +the mean and standard deviation of the sample. -```{r} -newDF <- createDataFrame(data.frame(x = c(1.5, 3.2))) -head(predict(isoregModel, newDF)) +```{r, warning=FALSE} +df <- createDataFrame(longley) +afStats <- head(select(df, mean(df$Armed_Forces), sd(df$Armed_Forces))) +afMean <- afStats[1] +afStd <- afStats[2] + +test <- spark.kstest(df, "Armed_Forces", "norm", c(afMean, afStd)) +testSummary <- summary(test) +testSummary ``` -#### What's More? -We also expect Decision Tree, Random Forest, Kolmogorov-Smirnov Test coming in the next version 2.1.0. ### Model Persistence -The following example shows how to save/load an ML model by SparkR. -```{r, warning=FALSE} -irisDF <- createDataFrame(iris) -gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +The following example shows how to save/load an ML model in SparkR. +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +gaussianGLM <- spark.glm(training, Freq ~ Sex + Age, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") @@ -785,13 +997,79 @@ gaussianGLM2 <- read.ml(modelPath) summary(gaussianGLM2) # Check model prediction -gaussianPredictions <- predict(gaussianGLM2, irisDF) +gaussianPredictions <- predict(gaussianGLM2, training) head(gaussianPredictions) unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. 
As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes @@ -802,19 +1080,19 @@ There are three main object classes in SparkR you may be working with. + `sdf` stores a reference to the corresponding Spark Dataset in the Spark JVM backend. + `env` saves the meta-information of the object such as `isCached`. -It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. + It can be created by data import methods or by transforming an existing `SparkDataFrame`. We can manipulate `SparkDataFrame` by numerous data processing functions and feed that into machine learning algorithms. -* `Column`: an S4 class representing column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding Column object in the Spark JVM backend. +* `Column`: an S4 class representing a column of `SparkDataFrame`. The slot `jc` saves a reference to the corresponding `Column` object in the Spark JVM backend. -It can be obtained from a `SparkDataFrame` by `$` operator, `df$col`. 
More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. + It can be obtained from a `SparkDataFrame` by `$` operator, e.g., `df$col`. More often, it is used together with other functions, for example, with `select` to select particular columns, with `filter` and constructed conditions to select rows, with aggregation functions to compute aggregate statistics for each group. -* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a RelationalGroupedDataset object in the backend. +* `GroupedData`: an S4 class representing grouped data created by `groupBy` or by transforming other `GroupedData`. Its `sgd` slot saves a reference to a `RelationalGroupedDataset` object in the backend. -This is often an intermediate object with group information and followed up by aggregation operations. + This is often an intermediate object with group information and followed up by aggregation operations. ### Architecture -A complete description of architecture can be seen in reference, in particular the paper *SparkR: Scaling R Programs with Spark*. +A complete description of architecture can be seen in the references, in particular the paper *SparkR: Scaling R Programs with Spark*. Under the hood of SparkR is Spark SQL engine. This avoids the overheads of running interpreted R code, and the optimized SQL execution engine in Spark uses structural information about data and computation flow to perform a bunch of optimizations to speed up the computation. @@ -822,9 +1100,9 @@ The main method calls of actual computation happen in the Spark JVM of the drive Two kinds of RPCs are supported in the SparkR JVM backend: method invocation and creating new objects. Method invocation can be done in two ways. -* `sparkR.invokeJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. +* `sparkR.callJMethod` takes a reference to an existing Java object and a list of arguments to be passed on to the method. -* `sparkR.invokeJStatic` takes a class name for static method and a list of arguments to be passed on to the method. +* `sparkR.callJStatic` takes a class name for static method and a list of arguments to be passed on to the method. The arguments are serialized using our custom wire format which is then deserialized on the JVM side. We then use Java reflection to invoke the appropriate method. 
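To make the two invocation styles concrete, here is a minimal sketch that borrows the calls appearing in the test files of this change; these are internal SparkR helpers rather than a stable public API, `sparkSession` is assumed to come from an active `sparkR.session()`, and the JVM method names are shown only for illustration.

```{r, eval=FALSE}
# Static method call: class name, method name, then arguments
sc <- SparkR:::callJStatic("org.apache.spark.sql.api.r.SQLUtils",
                           "getJavaSparkContext", sparkSession)
# Instance method call: a JVM object reference, method name, then arguments
SparkR:::callJMethod(sc, "version")
```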
diff --git a/R/run-tests.sh b/R/run-tests.sh index 5e4dafaf76f3d..29764f48bd156 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.default.name="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/README.md b/README.md index dd7d0e22495b3..1e521a7e7b178 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,7 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html) -and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). +guide, on the [project web page](http://spark.apache.org/documentation.html). This README file only contains basic setup instructions. ## Building Spark @@ -29,8 +28,8 @@ To build Spark and its example programs, run: You can build Spark using more than one thread by using the -T option with Maven, see ["Parallel builds in Maven 3"](https://cwiki.apache.org/confluence/display/MAVEN/Parallel+builds+in+Maven+3). More detailed documentation is available from the project site, at ["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). -For developing Spark using an IDE, see [Eclipse](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-Eclipse) -and [IntelliJ](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools#UsefulDeveloperTools-IntelliJ). + +For general development tips, including info on developing Spark using an IDE, see ["Useful Developer Tools"](http://spark.apache.org/developer-tools.html). ## Interactive Scala Shell @@ -80,7 +79,7 @@ can be run using: ./dev/run-tests Please see the guidance on how to -[run tests for a module, or individual tests](https://cwiki.apache.org/confluence/display/SPARK/Useful+Developer+Tools). +[run tests for a module, or individual tests](http://spark.apache.org/developer-tools.html#individual-tests). ## A Note About Hadoop Versions @@ -98,7 +97,7 @@ building for particular Hive and Hive Thriftserver distributions. Please refer to the [Configuration Guide](http://spark.apache.org/docs/latest/configuration.html) in the online documentation for an overview on how to configure Spark. -## Contributing +## Contributing -Please review the [Contribution to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) -wiki for information on how to get started contributing to the project. +Please review the [Contribution to Spark guide](http://spark.apache.org/contributing.html) +for information on how to get started contributing to the project. 
diff --git a/appveyor.yml b/appveyor.yml index 5e756835bcb9b..58c2e98289e96 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -27,6 +27,9 @@ branches: only_commits: files: - R/ + - sql/core/src/main/scala/org/apache/spark/sql/api/r/ + - core/src/main/scala/org/apache/spark/api/r/ + - mllib/src/main/scala/org/apache/spark/ml/r/ cache: - C:\Users\appveyor\.m2 @@ -43,14 +46,16 @@ install: - cmd: R -e "packageVersion('survival')" build_script: - - cmd: mvn -DskipTests -Phadoop-2.6 -Psparkr -Phive -Phive-thriftserver package + - cmd: mvn -DskipTests -Psparkr -Phive -Phive-thriftserver package + +environment: + NOT_CRAN: true test_script: - - cmd: .\bin\spark-submit2.cmd --conf spark.hadoop.fs.default.name="file:///" R\pkg\tests\run-all.R + - cmd: .\bin\spark-submit2.cmd --driver-java-options "-Dlog4j.configuration=file:///%CD:\=/%/R/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" R\pkg\tests\run-all.R notifications: - provider: Email on_build_success: false on_build_failure: false on_build_status_changed: false - diff --git a/assembly/README b/assembly/README index 14a5ff8dfc78f..d5dafab477410 100644 --- a/assembly/README +++ b/assembly/README @@ -9,4 +9,4 @@ This module is off by default. To activate it specify the profile in the command If you need to build an assembly for a different version of Hadoop the hadoop-version system property needs to be set as in this example: - -Dhadoop.version=2.0.6-alpha + -Dhadoop.version=2.7.3 diff --git a/assembly/pom.xml b/assembly/pom.xml index ec243eaebaea7..464af16e46f6e 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml @@ -187,6 +187,7 @@ org.apache.maven.plugins maven-assembly-plugin + 3.0.0 dist @@ -225,5 +226,19 @@ provided + + + + hadoop-cloud + + + org.apache.spark + spark-hadoop-cloud_${scala.binary.version} + ${project.version} + + + diff --git a/bin/beeline b/bin/beeline index 1627626941a73..058534699e44b 100755 --- a/bin/beeline +++ b/bin/beeline @@ -25,7 +25,7 @@ set -o posix # Figure out if SPARK_HOME is set if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi CLASS="org.apache.hive.beeline.BeeLine" diff --git a/bin/find-spark-home b/bin/find-spark-home new file mode 100755 index 0000000000000..fa78407d4175a --- /dev/null +++ b/bin/find-spark-home @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Attempts to find a proper value for SPARK_HOME. Should be included using "source" directive. + +FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" + +# Short cirtuit if the user already has this set. +if [ ! -z "${SPARK_HOME}" ]; then + exit 0 +elif [ ! 
-f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then + # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't + # need to search the different Python directories for a Spark installation. + # Note only that, if the user has pip installed PySpark but is directly calling pyspark-shell or + # spark-submit in another directory we want to use that version of PySpark rather than the + # pip installed version of PySpark. + export SPARK_HOME="$(cd "$(dirname "$0")"/..; pwd)" +else + # We are pip installed, use the Python script to resolve a reasonable SPARK_HOME + # Default to standard python interpreter unless told otherwise + if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"python"}" + fi + export SPARK_HOME=$($PYSPARK_DRIVER_PYTHON "$FIND_SPARK_HOME_PYTHON_SCRIPT") +fi diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index eaea964ed5b3d..8a2f709960a25 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -23,7 +23,7 @@ # Figure out where Spark is installed if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi if [ -z "$SPARK_ENV_LOADED" ]; then diff --git a/bin/pyspark b/bin/pyspark index d6b3ab0a44321..98387c2ec5b8a 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh @@ -46,7 +46,7 @@ WORKS_WITH_IPYTHON=$(python -c 'import sys; print(sys.version_info >= (2, 7, 0)) # Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! WORKS_WITH_IPYTHON ]]; then + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && ! $WORKS_WITH_IPYTHON ]]; then echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 exit 1 else @@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR export PYTHONHASHSEED=0 - exec "$PYSPARK_DRIVER_PYTHON" -m $1 + exec "$PYSPARK_DRIVER_PYTHON" -m "$1" exit fi diff --git a/bin/run-example b/bin/run-example index dd0e3c4120260..4ba5399311d33 100755 --- a/bin/run-example +++ b/bin/run-example @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/run-example [options] example-class [example args]" diff --git a/bin/spark-class b/bin/spark-class index 377c8d1add3f6..65d3b9612909a 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi . "${SPARK_HOME}"/bin/load-spark-env.sh @@ -27,7 +27,7 @@ fi if [ -n "${JAVA_HOME}" ]; then RUNNER="${JAVA_HOME}/bin/java" else - if [ `command -v java` ]; then + if [ "$(command -v java)" ]; then RUNNER="java" else echo "JAVA_HOME is not set" >&2 @@ -36,7 +36,7 @@ else fi # Find Spark jars. -if [ -f "${SPARK_HOME}/RELEASE" ]; then +if [ -d "${SPARK_HOME}/jars" ]; then SPARK_JARS_DIR="${SPARK_HOME}/jars" else SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" @@ -72,6 +72,8 @@ build_command() { printf "%d\0" $? 
} +# Turn off posix mode since it does not allow process substitution +set +o posix CMD=() while IFS= read -d '' -r ARG; do CMD+=("$ARG") diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 869c0b202f7f3..f6157f42843e8 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -50,7 +50,16 @@ if not "x%SPARK_PREPEND_CLASSES%"=="x" ( rem Figure out where java is. set RUNNER=java -if not "x%JAVA_HOME%"=="x" set RUNNER=%JAVA_HOME%\bin\java +if not "x%JAVA_HOME%"=="x" ( + set RUNNER=%JAVA_HOME%\bin\java +) else ( + where /q "%RUNNER%" + if ERRORLEVEL 1 ( + echo Java not found and JAVA_HOME environment variable is not set. + echo Install Java and set JAVA_HOME to point to the Java installation directory. + exit /b 1 + ) +) rem The launcher library prints the command to be executed in a single line suitable for being rem executed by the batch interpreter. So read all the output of the launcher into a variable. diff --git a/bin/spark-shell b/bin/spark-shell index 6583b5bd880ee..421f36cac3d47 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -21,7 +21,7 @@ # Shell script for starting the Spark Shell REPL cygwin=false -case "`uname`" in +case "$(uname)" in CYGWIN*) cygwin=true;; esac @@ -29,7 +29,7 @@ esac set -o posix if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-shell [options]" diff --git a/bin/spark-sql b/bin/spark-sql index 970d12cbf51dd..b08b944ebd319 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi export _SPARK_CMD_USAGE="Usage: ./bin/spark-sql [options] [cli option]" diff --git a/bin/spark-submit b/bin/spark-submit index 023f9c162f4b8..4e9d3614e6370 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi # disable randomized hash for string in Python 3.3+ diff --git a/bin/sparkR b/bin/sparkR index 2c07a82e2173b..29ab10df8ab6d 100755 --- a/bin/sparkR +++ b/bin/sparkR @@ -18,7 +18,7 @@ # if [ -z "${SPARK_HOME}" ]; then - export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" + source "$(dirname "$0")"/find-spark-home fi source "${SPARK_HOME}"/bin/load-spark-env.sh diff --git a/build/mvn b/build/mvn index c3ab62da36868..1e393c331dd8b 100755 --- a/build/mvn +++ b/build/mvn @@ -22,7 +22,7 @@ _DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # Preserve the calling directory _CALLING_DIR="$(pwd)" # Options used during compilation -_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" +_COMPILE_JVM_OPTS="-Xmx2g -XX:ReservedCodeCacheSize=512m" # Installs any application tarball given a URL, the expected tarball name, # and, optionally, a checkable binary path to determine if the binary has @@ -91,13 +91,13 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.9/bin/zinc" + local zinc_path="zinc-0.3.11/bin/zinc" [ ! 
-f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "${TYPESAFE_MIRROR}/zinc/0.3.9" \ - "zinc-0.3.9.tgz" \ + "${TYPESAFE_MIRROR}/zinc/0.3.11" \ + "zinc-0.3.11.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } @@ -141,13 +141,9 @@ cd "${_CALLING_DIR}" # Now that zinc is ensured to be installed, check its status and, if its # not running or just installed, start it if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`"${ZINC_BIN}" -status -port ${ZINC_PORT}`" ]; then - ZINC_JAVA_HOME= - if [ -n "$JAVA_7_HOME" ]; then - ZINC_JAVA_HOME="env JAVA_HOME=$JAVA_7_HOME" - fi export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"} "${ZINC_BIN}" -shutdown -port ${ZINC_PORT} - $ZINC_JAVA_HOME "${ZINC_BIN}" -start -port ${ZINC_PORT} \ + "${ZINC_BIN}" -start -port ${ZINC_PORT} \ -scala-compiler "${SCALA_COMPILER}" \ -scala-library "${SCALA_LIBRARY}" &>/dev/null fi diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index 615f848394650..4732669ee651f 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -117,7 +117,7 @@ get_mem_opts () { (( $perm < 4096 )) || perm=4096 local codecache=$(( $perm / 2 )) - echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m" + echo "-Xms${mem}m -Xmx${mem}m -XX:ReservedCodeCacheSize=${codecache}m" } require_arg () { diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index fcefe64d59c91..066970f24205f 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -76,6 +76,10 @@ guava compile + + org.apache.commons + commons-crypto + @@ -87,6 +91,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.mockito mockito-core diff --git a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java index 5b69e2bb03546..965c4ae307667 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/common/network-common/src/main/java/org/apache/spark/network/TransportContext.java @@ -17,9 +17,9 @@ package org.apache.spark.network; +import java.util.ArrayList; import java.util.List; -import com.google.common.collect.Lists; import io.netty.channel.Channel; import io.netty.channel.socket.SocketChannel; import io.netty.handler.timeout.IdleStateHandler; @@ -62,8 +62,20 @@ public class TransportContext { private final RpcHandler rpcHandler; private final boolean closeIdleConnections; - private final MessageEncoder encoder; - private final MessageDecoder decoder; + /** + * Force to create MessageEncoder and MessageDecoder so that we can make sure they will be created + * before switching the current context class loader to ExecutorClassLoader. + * + * Netty's MessageToMessageEncoder uses Javassist to generate a matcher class and the + * implementation calls "Class.forName" to check if this calls is already generated. If the + * following two objects are created in "ExecutorClassLoader.findClass", it will cause + * "ClassCircularityError". This is because loading this Netty generated class will call + * "ExecutorClassLoader.findClass" to search this class, and "ExecutorClassLoader" will try to use + * RPC to load it and cause to load the non-exist matcher class again. 
JVM will report + * `ClassCircularityError` to prevent such infinite recursion. (See SPARK-17714) + */ + private static final MessageEncoder ENCODER = MessageEncoder.INSTANCE; + private static final MessageDecoder DECODER = MessageDecoder.INSTANCE; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { this(conf, rpcHandler, false); @@ -75,8 +87,6 @@ public TransportContext( boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; - this.encoder = new MessageEncoder(); - this.decoder = new MessageDecoder(); this.closeIdleConnections = closeIdleConnections; } @@ -90,7 +100,7 @@ public TransportClientFactory createClientFactory(List } public TransportClientFactory createClientFactory() { - return createClientFactory(Lists.newArrayList()); + return createClientFactory(new ArrayList<>()); } /** Create a server which will attempt to bind to a specific port. */ @@ -110,7 +120,7 @@ public TransportServer createServer(List bootstraps) { } public TransportServer createServer() { - return createServer(0, Lists.newArrayList()); + return createServer(0, new ArrayList<>()); } public TransportChannelHandler initializePipeline(SocketChannel channel) { @@ -135,9 +145,9 @@ public TransportChannelHandler initializePipeline( try { TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler); channel.pipeline() - .addLast("encoder", encoder) + .addLast("encoder", ENCODER) .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder()) - .addLast("decoder", decoder) + .addLast("decoder", DECODER) .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000)) // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this // would require more logic to guarantee if this were not part of the same event loop. 
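The class-loading subtlety above can be illustrated with a small, hypothetical sketch of the pattern the change adopts: a shared static instance makes the codec classes load together with the enclosing class, before any custom context class loader (such as a REPL class loader) is installed, and pipeline setup then reuses that instance instead of constructing one per connection. The names below are invented and this is not Spark's code:

```java
// Hypothetical sketch of the eager, shared-codec pattern (names are illustrative).
final class Codec {
  static final Codec INSTANCE = new Codec();  // created when Codec is loaded and initialized
  private Codec() {}
}

final class Pipeline {
  // Referencing the singleton here forces Codec to be resolved as soon as Pipeline is,
  // so later pipeline setup never has to load codec classes through a custom class loader.
  private static final Codec ENCODER = Codec.INSTANCE;

  void wire() {
    // ... reuse ENCODER for every connection; no per-connection construction or class loading
  }
}
```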
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 7e7d78d42a8fb..a6f527c118218 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -32,8 +32,6 @@ import com.google.common.base.Throwables; import com.google.common.util.concurrent.SettableFuture; import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -133,40 +131,36 @@ public void setClientId(String id) { */ public void fetchChunk( long streamId, - final int chunkIndex, - final ChunkReceivedCallback callback) { - final long startTime = System.currentTimeMillis(); + int chunkIndex, + ChunkReceivedCallback callback) { + long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel)); } - final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); handler.addFetchRequest(streamChunkId, callback); - channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request {} to {} took {} ms", streamChunkId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - handler.removeFetchRequest(streamChunkId); - channel.close(); - try { - callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(future -> { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + if (logger.isTraceEnabled()) { + logger.trace("Sending request {} to {} took {} ms", streamChunkId, + getRemoteAddress(channel), timeTaken); } - }); + } else { + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeFetchRequest(streamChunkId); + channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + }); } /** @@ -175,8 +169,8 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param streamId The stream to fetch. * @param callback Object to call with the stream data. 
*/ - public void stream(final String streamId, final StreamCallback callback) { - final long startTime = System.currentTimeMillis(); + public void stream(String streamId, StreamCallback callback) { + long startTime = System.currentTimeMillis(); if (logger.isDebugEnabled()) { logger.debug("Sending stream request for {} to {}", streamId, getRemoteAddress(channel)); } @@ -186,29 +180,25 @@ public void stream(final String streamId, final StreamCallback callback) { // when responses arrive. synchronized (this) { handler.addStreamCallback(callback); - channel.writeAndFlush(new StreamRequest(streamId)).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - long timeTaken = System.currentTimeMillis() - startTime; - if (logger.isTraceEnabled()) { - logger.trace("Sending request for {} to {} took {} ms", streamId, - getRemoteAddress(channel), timeTaken); - } - } else { - String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, - getRemoteAddress(channel), future.cause()); - logger.error(errorMsg, future.cause()); - channel.close(); - try { - callback.onFailure(streamId, new IOException(errorMsg, future.cause())); - } catch (Exception e) { - logger.error("Uncaught exception in RPC response callback handler!", e); - } - } + channel.writeAndFlush(new StreamRequest(streamId)).addListener(future -> { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + if (logger.isTraceEnabled()) { + logger.trace("Sending request for {} to {} took {} ms", streamId, + getRemoteAddress(channel), timeTaken); } - }); + } else { + String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId, + getRemoteAddress(channel), future.cause()); + logger.error(errorMsg, future.cause()); + channel.close(); + try { + callback.onFailure(streamId, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + }); } } @@ -220,19 +210,17 @@ public void operationComplete(ChannelFuture future) throws Exception { * @param callback Callback to handle the RPC's reply. * @return The RPC's id. 
*/ - public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) { - final long startTime = System.currentTimeMillis(); + public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { + long startTime = System.currentTimeMillis(); if (logger.isTraceEnabled()) { logger.trace("Sending RPC to {}", getRemoteAddress(channel)); } - final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); handler.addRpcRequest(requestId, callback); - channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { + channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))) + .addListener(future -> { if (future.isSuccess()) { long timeTaken = System.currentTimeMillis() - startTime; if (logger.isTraceEnabled()) { @@ -251,8 +239,7 @@ public void operationComplete(ChannelFuture future) throws Exception { logger.error("Uncaught exception in RPC response callback handler!", e); } } - } - }); + }); return requestId; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index e895f13f45458..b50e043d5c9ce 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -100,8 +100,10 @@ public TransportClientFactory( IOMode ioMode = IOMode.valueOf(conf.ioMode()); this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); - // TODO: Make thread pool name configurable. - this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + this.workerGroup = NettyUtils.createEventLoop( + ioMode, + conf.clientThreads(), + conf.getModuleName() + "-client"); this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); } @@ -120,7 +122,8 @@ public TransportClientFactory( * * Concurrency: This method is safe to call from multiple threads. */ - public TransportClient createClient(String remoteHost, int remotePort) throws IOException { + public TransportClient createClient(String remoteHost, int remotePort) + throws IOException, InterruptedException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. // Use unresolved address here to avoid DNS resolution each time we creates a client. @@ -188,13 +191,14 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO * As with {@link #createClient(String, int)}, this method is blocking. */ public TransportClient createUnmanagedClient(String remoteHost, int remotePort) - throws IOException { + throws IOException, InterruptedException { final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); return createClient(address); } /** Create a completely new {@link TransportClient} to the remote address. 
*/ - private TransportClient createClient(InetSocketAddress address) throws IOException { + private TransportClient createClient(InetSocketAddress address) + throws IOException, InterruptedException { logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); @@ -221,7 +225,7 @@ public void initChannel(SocketChannel ch) { // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); - if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + if (!cf.await(conf.connectionTimeoutMs())) { throw new IOException( String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); } else if (cf.cause() != null) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java new file mode 100644 index 0000000000000..799f4540aa934 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; + +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing authentication using Spark's auth protocol. + * + * This bootstrap falls back to using the SASL bootstrap if the server throws an error during + * authentication, and the configuration allows it. This is used for backwards compatibility + * with external shuffle services that do not support the new protocol. + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. 
+ */ +public class AuthClientBootstrap implements TransportClientBootstrap { + + private static final Logger LOG = LoggerFactory.getLogger(AuthClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final String authUser; + private final SecretKeyHolder secretKeyHolder; + + public AuthClientBootstrap( + TransportConf conf, + String appId, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + // TODO: right now this behaves like the SASL backend, because when executors start up + // they don't necessarily know the app ID. So they send a hardcoded "user" that is defined + // in the SecurityManager, which will also always return the same secret (regardless of the + // user name). All that's needed here is for this "user" to match on both sides, since that's + // required by the protocol. At some point, though, it would be better for the actual app ID + // to be provided here. + this.appId = appId; + this.authUser = secretKeyHolder.getSaslUser(appId); + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void doBootstrap(TransportClient client, Channel channel) { + if (!conf.encryptionEnabled()) { + LOG.debug("AES encryption disabled, using old auth protocol."); + doSaslAuth(client, channel); + return; + } + + try { + doSparkAuth(client, channel); + } catch (GeneralSecurityException | IOException e) { + throw Throwables.propagate(e); + } catch (RuntimeException e) { + // There isn't a good exception that can be caught here to know whether it's really + // OK to switch back to SASL (because the server doesn't speak the new protocol). So + // try it anyway, and in the worst case things will fail again. + if (conf.saslFallback()) { + LOG.warn("New auth protocol failed, trying SASL.", e); + doSaslAuth(client, channel); + } else { + throw e; + } + } + } + + private void doSparkAuth(TransportClient client, Channel channel) + throws GeneralSecurityException, IOException { + + String secretKey = secretKeyHolder.getSecretKey(authUser); + try (AuthEngine engine = new AuthEngine(authUser, secretKey, conf)) { + ClientChallenge challenge = engine.challenge(); + ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); + challenge.encode(challengeData); + + ByteBuffer responseData = + client.sendRpcSync(challengeData.nioBuffer(), conf.authRTTimeoutMs()); + ServerResponse response = ServerResponse.decodeMessage(responseData); + + engine.validate(response); + engine.sessionCipher().addToChannel(channel); + } + } + + private void doSaslAuth(TransportClient client, Channel channel) { + SaslClientBootstrap sasl = new SaslClientBootstrap(conf, appId, secretKeyHolder); + sasl.doBootstrap(client, channel); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java new file mode 100644 index 0000000000000..b769ebeba36cc --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.Closeable; +import java.io.IOException; +import java.math.BigInteger; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import java.util.Properties; +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.ShortBufferException; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Bytes; +import org.apache.commons.crypto.cipher.CryptoCipher; +import org.apache.commons.crypto.cipher.CryptoCipherFactory; +import org.apache.commons.crypto.random.CryptoRandom; +import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.util.TransportConf; + +/** + * A helper class for abstracting authentication and key negotiation details. This is used by + * both client and server sides, since the operations are basically the same. + */ +class AuthEngine implements Closeable { + + private static final Logger LOG = LoggerFactory.getLogger(AuthEngine.class); + private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); + + private final byte[] appId; + private final char[] secret; + private final TransportConf conf; + private final Properties cryptoConf; + private final CryptoRandom random; + + private byte[] authNonce; + + @VisibleForTesting + byte[] challenge; + + private TransportCipher sessionCipher; + private CryptoCipher encryptor; + private CryptoCipher decryptor; + + AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { + this.appId = appId.getBytes(UTF_8); + this.conf = conf; + this.cryptoConf = conf.cryptoConf(); + this.secret = secret.toCharArray(); + this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); + } + + /** + * Create the client challenge. + * + * @return A challenge to be sent the remote side. + */ + ClientChallenge challenge() throws GeneralSecurityException, IOException { + this.authNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + SecretKeySpec authKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + authNonce, conf.encryptionKeyLength()); + initializeForAuth(conf.cipherTransformation(), authNonce, authKey); + + this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + return new ClientChallenge(new String(appId, UTF_8), + conf.keyFactoryAlgorithm(), + conf.keyFactoryIterations(), + conf.cipherTransformation(), + conf.encryptionKeyLength(), + authNonce, + challenge(appId, authNonce, challenge)); + } + + /** + * Validates the client challenge, and create the encryption backend for the channel from the + * parameters sent by the client. + * + * @param clientChallenge The challenge from the client. + * @return A response to be sent to the client. 
+ */ + ServerResponse respond(ClientChallenge clientChallenge) + throws GeneralSecurityException, IOException { + + SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + clientChallenge.nonce, clientChallenge.keyLength); + initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); + + byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); + byte[] response = challenge(appId, clientChallenge.nonce, rawResponse(challenge)); + byte[] sessionNonce = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); + byte[] inputIv = randomBytes(conf.ivLength()); + byte[] outputIv = randomBytes(conf.ivLength()); + + SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, + sessionNonce, clientChallenge.keyLength); + this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, + inputIv, outputIv); + + // Note the IVs are swapped in the response. + return new ServerResponse(response, encrypt(sessionNonce), encrypt(outputIv), encrypt(inputIv)); + } + + /** + * Validates the server response and initializes the cipher to use for the session. + * + * @param serverResponse The response from the server. + */ + void validate(ServerResponse serverResponse) throws GeneralSecurityException { + byte[] response = validateChallenge(authNonce, serverResponse.response); + + byte[] expected = rawResponse(challenge); + Preconditions.checkArgument(Arrays.equals(expected, response)); + + byte[] nonce = decrypt(serverResponse.nonce); + byte[] inputIv = decrypt(serverResponse.inputIv); + byte[] outputIv = decrypt(serverResponse.outputIv); + + SecretKeySpec sessionKey = generateKey(conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), + nonce, conf.encryptionKeyLength()); + this.sessionCipher = new TransportCipher(cryptoConf, conf.cipherTransformation(), sessionKey, + inputIv, outputIv); + } + + TransportCipher sessionCipher() { + Preconditions.checkState(sessionCipher != null); + return sessionCipher; + } + + @Override + public void close() throws IOException { + // Close ciphers (by calling "doFinal()" with dummy data) and the random instance so that + // internal state is cleaned up. Error handling here is just for paranoia, and not meant to + // accurately report the errors when they happen. 
+ RuntimeException error = null; + byte[] dummy = new byte[8]; + try { + doCipherOp(encryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + try { + doCipherOp(decryptor, dummy, true); + } catch (Exception e) { + error = new RuntimeException(e); + } + random.close(); + + if (error != null) { + throw error; + } + } + + @VisibleForTesting + byte[] challenge(byte[] appId, byte[] nonce, byte[] challenge) throws GeneralSecurityException { + return encrypt(Bytes.concat(appId, nonce, challenge)); + } + + @VisibleForTesting + byte[] rawResponse(byte[] challenge) { + BigInteger orig = new BigInteger(challenge); + BigInteger response = orig.add(ONE); + return response.toByteArray(); + } + + private byte[] decrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(decryptor, in, false); + } + + private byte[] encrypt(byte[] in) throws GeneralSecurityException { + return doCipherOp(encryptor, in, false); + } + + private void initializeForAuth(String cipher, byte[] nonce, SecretKeySpec key) + throws GeneralSecurityException { + + // commons-crypto currently only supports ciphers that require an initial vector; so + // create a dummy vector so that we can initialize the ciphers. In the future, if + // different ciphers are supported, this will have to be configurable somehow. + byte[] iv = new byte[conf.ivLength()]; + System.arraycopy(nonce, 0, iv, 0, Math.min(nonce.length, iv.length)); + + encryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + encryptor.init(Cipher.ENCRYPT_MODE, key, new IvParameterSpec(iv)); + + decryptor = CryptoCipherFactory.getCryptoCipher(cipher, cryptoConf); + decryptor.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv)); + } + + /** + * Validates an encrypted challenge as defined in the protocol, and returns the byte array + * that corresponds to the actual challenge data. + */ + private byte[] validateChallenge(byte[] nonce, byte[] encryptedChallenge) + throws GeneralSecurityException { + + byte[] challenge = decrypt(encryptedChallenge); + checkSubArray(appId, challenge, 0); + checkSubArray(nonce, challenge, appId.length); + return Arrays.copyOfRange(challenge, appId.length + nonce.length, challenge.length); + } + + private SecretKeySpec generateKey(String kdf, int iterations, byte[] salt, int keyLength) + throws GeneralSecurityException { + + SecretKeyFactory factory = SecretKeyFactory.getInstance(kdf); + PBEKeySpec spec = new PBEKeySpec(secret, salt, iterations, keyLength); + + long start = System.nanoTime(); + SecretKey key = factory.generateSecret(spec); + long end = System.nanoTime(); + + LOG.debug("Generated key with {} iterations in {} us.", conf.keyFactoryIterations(), + (end - start) / 1000); + + return new SecretKeySpec(key.getEncoded(), conf.keyAlgorithm()); + } + + private byte[] doCipherOp(CryptoCipher cipher, byte[] in, boolean isFinal) + throws GeneralSecurityException { + + Preconditions.checkState(cipher != null); + + int scale = 1; + while (true) { + int size = in.length * scale; + byte[] buffer = new byte[size]; + try { + int outSize = isFinal ? cipher.doFinal(in, 0, in.length, buffer, 0) + : cipher.update(in, 0, in.length, buffer, 0); + if (outSize != buffer.length) { + byte[] output = new byte[outSize]; + System.arraycopy(buffer, 0, output, 0, output.length); + return output; + } else { + return buffer; + } + } catch (ShortBufferException e) { + // Try again with a bigger buffer. 
+ scale *= 2; + } + } + } + + private byte[] randomBytes(int count) { + byte[] bytes = new byte[count]; + random.nextBytes(bytes); + return bytes; + } + + /** Checks that the "test" array is in the data array starting at the given offset. */ + private void checkSubArray(byte[] test, byte[] data, int offset) { + Preconditions.checkArgument(data.length >= test.length + offset); + for (int i = 0; i < test.length; i++) { + Preconditions.checkArgument(test[i] == data[i + offset]); + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java new file mode 100644 index 0000000000000..0a5c029940005 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Throwables; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.TransportConf; + +/** + * RPC Handler which performs authentication using Spark's auth protocol before delegating to a + * child RPC handler. If the configuration allows, this handler will delegate messages to a SASL + * RPC handler for further authentication, to support for clients that do not support Spark's + * protocol. + * + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + */ +class AuthRpcHandler extends RpcHandler { + private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class); + + /** Transport configuration. */ + private final TransportConf conf; + + /** The client channel. */ + private final Channel channel; + + /** + * RpcHandler we will delegate to for authenticated connections. When falling back to SASL + * this will be replaced with the SASL RPC handler. + */ + @VisibleForTesting + RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. 
*/ + private final SecretKeyHolder secretKeyHolder; + + /** Whether auth is done and future calls should be delegated. */ + @VisibleForTesting + boolean doDelegate; + + AuthRpcHandler( + TransportConf conf, + Channel channel, + RpcHandler delegate, + SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.channel = channel; + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + } + + @Override + public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) { + if (doDelegate) { + delegate.receive(client, message, callback); + return; + } + + int position = message.position(); + int limit = message.limit(); + + ClientChallenge challenge; + try { + challenge = ClientChallenge.decodeMessage(message); + LOG.debug("Received new auth challenge for client {}.", channel.remoteAddress()); + } catch (RuntimeException e) { + if (conf.saslFallback()) { + LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.", + channel.remoteAddress()); + delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder); + message.position(position); + message.limit(limit); + delegate.receive(client, message, callback); + doDelegate = true; + } else { + LOG.debug("Unexpected challenge message from client {}, closing channel.", + channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Unknown challenge message.")); + channel.close(); + } + return; + } + + // Here we have the client challenge, so perform the new auth protocol and set up the channel. + AuthEngine engine = null; + try { + engine = new AuthEngine(challenge.appId, secretKeyHolder.getSecretKey(challenge.appId), conf); + ServerResponse response = engine.respond(challenge); + ByteBuf responseData = Unpooled.buffer(response.encodedLength()); + response.encode(responseData); + callback.onSuccess(responseData.nioBuffer()); + engine.sessionCipher().addToChannel(channel); + } catch (Exception e) { + // This is a fatal error: authentication has failed. Close the channel explicitly. 
+ LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress()); + callback.onFailure(new IllegalArgumentException("Authentication failed.")); + channel.close(); + return; + } finally { + if (engine != null) { + try { + engine.close(); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + } + + LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); + doDelegate = true; + } + + @Override + public void receive(TransportClient client, ByteBuffer message) { + delegate.receive(client, message); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void channelActive(TransportClient client) { + delegate.channelActive(client); + } + + @Override + public void channelInactive(TransportClient client) { + delegate.channelInactive(client); + } + + @Override + public void exceptionCaught(Throwable cause, TransportClient client) { + delegate.exceptionCaught(cause, client); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java new file mode 100644 index 0000000000000..77a2a6af4d134 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthServerBootstrap.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.channel.Channel; + +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * A bootstrap which is executed on a TransportServer's client channel once a client connects + * to the server, enabling authentication using Spark's auth protocol (and optionally SASL for + * clients that don't support the new protocol). + * + * It also automatically falls back to SASL if the new encryption backend is disabled, so that + * callers only need to install this bootstrap when authentication is enabled. 
+ */ +public class AuthServerBootstrap implements TransportServerBootstrap { + + private final TransportConf conf; + private final SecretKeyHolder secretKeyHolder; + + public AuthServerBootstrap(TransportConf conf, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + } + + public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { + if (!conf.encryptionEnabled()) { + TransportServerBootstrap sasl = new SaslServerBootstrap(conf, secretKeyHolder); + return sasl.doBootstrap(channel, rpcHandler); + } + + return new AuthRpcHandler(conf, channel, rpcHandler, secretKeyHolder); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java new file mode 100644 index 0000000000000..819b8a7efbdba --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * The client challenge message, used to initiate authentication. + * + * Please see crypto/README.md for more details of implementation. + */ +public class ClientChallenge implements Encodable { + /** Serialization tag used to catch incorrect payloads. 
*/ + private static final byte TAG_BYTE = (byte) 0xFA; + + public final String appId; + public final String kdf; + public final int iterations; + public final String cipher; + public final int keyLength; + public final byte[] nonce; + public final byte[] challenge; + + public ClientChallenge( + String appId, + String kdf, + int iterations, + String cipher, + int keyLength, + byte[] nonce, + byte[] challenge) { + this.appId = appId; + this.kdf = kdf; + this.iterations = iterations; + this.cipher = cipher; + this.keyLength = keyLength; + this.nonce = nonce; + this.challenge = challenge; + } + + @Override + public int encodedLength() { + return 1 + 4 + 4 + + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(kdf) + + Encoders.Strings.encodedLength(cipher) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(challenge); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, kdf); + buf.writeInt(iterations); + Encoders.Strings.encode(buf, cipher); + buf.writeInt(keyLength); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, challenge); + } + + public static ClientChallenge decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ClientChallenge, received something else."); + } + + return new ClientChallenge( + Encoders.Strings.decode(buf), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.Strings.decode(buf), + buf.readInt(), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md new file mode 100644 index 0000000000000..14df703270498 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/README.md @@ -0,0 +1,158 @@ +Spark Auth Protocol and AES Encryption Support +============================================== + +This file describes an auth protocol used by Spark as a more secure alternative to DIGEST-MD5. This +protocol is built on symmetric key encryption, based on the assumption that the two endpoints being +authenticated share a common secret, which is how Spark authentication currently works. The protocol +provides mutual authentication, meaning that after the negotiation both parties know that the remote +side knows the shared secret. The protocol is influenced by the ISO/IEC 9798 protocol, although it's +not an implementation of it. + +This protocol could be replaced with TLS PSK, except no PSK ciphers are available in the currently +released JREs. + +The protocol aims at solving the following shortcomings in Spark's current usage of DIGEST-MD5: + +- MD5 is an aging hash algorithm with known weaknesses, and a more secure alternative is desired. +- DIGEST-MD5 has a pre-defined set of ciphers for which it can generate keys. The only + viable, supported cipher these days is 3DES, and a more modern alternative is desired. +- Encrypting AES session keys with 3DES doesn't solve the issue, since the weakest link + in the negotiation would still be MD5 and 3DES. + +The protocol assumes that the shared secret is generated and distributed in a secure manner. + +The protocol always negotiates encryption keys. 
If encryption is not desired, the existing +SASL-based authentication, or no authentication at all, can be chosen instead. + +When messages are described below, it's expected that the implementation should support +arbitrary sizes for fields that don't have a fixed size. + +Client Challenge +---------------- + +The auth negotiation is started by the client. The client starts by generating an encryption +key based on the application's shared secret, and a nonce. + + KEY = KDF(SECRET, SALT, KEY_LENGTH) + +Where: +- KDF(): a key derivation function that takes a secret, a salt, a configurable number of + iterations, and a configurable key length. +- SALT: a byte sequence used to salt the key derivation function. +- KEY_LENGTH: length of the encryption key to generate. + + +The client generates a message with the following content: + + CLIENT_CHALLENGE = ( + APP_ID, + KDF, + ITERATIONS, + CIPHER, + KEY_LENGTH, + ANONCE, + ENC(APP_ID || ANONCE || CHALLENGE)) + +Where: + +- APP_ID: the application ID which the server uses to identify the shared secret. +- KDF: the key derivation function described above. +- ITERATIONS: number of iterations to run the KDF when generating keys. +- CIPHER: the cipher used to encrypt data. +- KEY_LENGTH: length of the encryption keys to generate, in bits. +- ANONCE: the nonce used as the salt when generating the auth key. +- ENC(): an encryption function that uses the cipher and the generated key. This function + will also be used in the definition of other messages below. +- CHALLENGE: a byte sequence used as a challenge to the server. +- ||: concatenation operator. + +When strings are used where byte arrays are expected, the UTF-8 representation of the string +is assumed. + +To respond to the challenge, the server should consider the byte array as representing an +arbitrary-length integer, and respond with the value of the integer plus one. + + +Server Response And Challenge +----------------------------- + +Once the client challenge is received, the server will generate the same auth key by +using the same algorithm the client has used. It will then verify the client challenge: +if the APP_ID and ANONCE fields match, the server knows that the client has the shared +secret. The server then creates a response to the client challenge, to prove that it also +has the secret key, and provides parameters to be used when creating the session key. + +The following describes the response from the server: + + SERVER_CHALLENGE = ( + ENC(APP_ID || ANONCE || RESPONSE), + ENC(SNONCE), + ENC(INIV), + ENC(OUTIV)) + +Where: + +- RESPONSE: the server's response to the client challenge. +- SNONCE: a nonce to be used as salt when generating the session key. +- INIV: initialization vector used to initialize the input channel of the client. +- OUTIV: initialization vector used to initialize the output channel of the client. + +At this point the server considers the client to be authenticated, and will try to +decrypt any data further sent by the client using the session key. + + +Default Algorithms +------------------ + +Configuration options are available for the KDF and cipher algorithms to use. + +The default KDF is "PBKDF2WithHmacSHA1". Users should be able to select any algorithm +from those supported by the `javax.crypto.SecretKeyFactory` class, as long as they support +PBEKeySpec when generating keys. The default number of iterations was chosen to take a +reasonable amount of time on modern CPUs. See the documentation in TransportConf for more +details. 
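The two computations at the heart of the protocol, key derivation and the plus-one challenge response, can be sketched with standard JDK classes. The snippet below is an illustration with example parameter values, not the `AuthEngine` implementation, which additionally encrypts the challenge fields, derives IVs, and reads all parameters from configuration:

```java
import java.math.BigInteger;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;

final class AuthSketch {
  // KEY = KDF(SECRET, SALT, KEY_LENGTH): derive a 128-bit AES key with the default KDF.
  static SecretKeySpec deriveKey(char[] secret, byte[] salt) throws Exception {
    SecretKeyFactory kdf = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1");
    PBEKeySpec spec = new PBEKeySpec(secret, salt, 10000, 128);  // iteration count chosen for the example
    return new SecretKeySpec(kdf.generateSecret(spec).getEncoded(), "AES");
  }

  // RESPONSE: treat the decrypted challenge bytes as an arbitrary-length integer and add one.
  static byte[] response(byte[] challenge) {
    return new BigInteger(challenge).add(BigInteger.ONE).toByteArray();
  }
}
```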
+ +The default cipher algorithm is "AES/CTR/NoPadding". Users should be able to select any +algorithm supported by the commons-crypto library. It should allow the cipher to operate +in stream mode. + +The default key length is 128 (bits). + + +Implementation Details +---------------------- + +The commons-crypto library currently only supports AES ciphers, and requires an initialization +vector (IV). This first version of the protocol does not explicitly include the IV in the client +challenge message. Instead, the IV should be derived from the nonce, including the needed bytes, and +padding the IV with zeroes in case the nonce is not long enough. + +Future versions of the protocol might add support for new ciphers and explicitly include needed +configuration parameters in the messages. + + +Threat Assessment +----------------- + +The protocol is secure against different forms of attack: + +* Eavesdropping: the protocol is built on the assumption that it's computationally infeasible + to calculate the original secret from the encrypted messages. Neither the secret nor any + encryption keys are transmitted on the wire, encrypted or not. + +* Man-in-the-middle: because the protocol performs mutual authentication, both ends need to + know the shared secret to be able to decrypt session data. Even if an attacker is able to insert a + malicious "proxy" between endpoints, the attacker won't be able to read any of the data exchanged + between client and server, nor insert arbitrary commands for the server to execute. + +* Replay attacks: the use of nonces when generating keys prevents an attacker from being able to + just replay messages sniffed from the communication channel. + +An attacker may replay the client challenge and successfully "prove" to a server that it "knows" the +shared secret. But the attacker won't be able to decrypt the server's response, and thus won't be +able to generate a session key, which will make it hard to craft a valid, encrypted message that the +server will be able to understand. This will cause the server to close the connection as soon as the +attacker tries to send any command to the server. The attacker can just hold the channel open for +some time, which will be closed when the server times out the channel. These issues could be +separately mitigated by adding a shorter timeout for the first message after authentication, and +potentially by adding host blacklists if a possible attack is detected from a particular host. diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java new file mode 100644 index 0000000000000..caf3a0f3b38cc --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ServerResponse.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Server's response to client's challenge. + * + * Please see crypto/README.md for more details. + */ +public class ServerResponse implements Encodable { + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xFB; + + public final byte[] response; + public final byte[] nonce; + public final byte[] inputIv; + public final byte[] outputIv; + + public ServerResponse( + byte[] response, + byte[] nonce, + byte[] inputIv, + byte[] outputIv) { + this.response = response; + this.nonce = nonce; + this.inputIv = inputIv; + this.outputIv = outputIv; + } + + @Override + public int encodedLength() { + return 1 + + Encoders.ByteArrays.encodedLength(response) + + Encoders.ByteArrays.encodedLength(nonce) + + Encoders.ByteArrays.encodedLength(inputIv) + + Encoders.ByteArrays.encodedLength(outputIv); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.ByteArrays.encode(buf, response); + Encoders.ByteArrays.encode(buf, nonce); + Encoders.ByteArrays.encode(buf, inputIv); + Encoders.ByteArrays.encode(buf, outputIv); + } + + public static ServerResponse decodeMessage(ByteBuffer buffer) { + ByteBuf buf = Unpooled.wrappedBuffer(buffer); + + if (buf.readByte() != TAG_BYTE) { + throw new IllegalArgumentException("Expected ServerResponse, received something else."); + } + + return new ServerResponse( + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf), + Encoders.ByteArrays.decode(buf)); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java new file mode 100644 index 0000000000000..7376d1ddc4818 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.AbstractReferenceCounted; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; + +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; + +/** + * Cipher for encryption and decryption. + */ +public class TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; + private static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private final String cipher; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public TransportCipher( + Properties conf, + String cipher, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.cipher = cipher; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + public String getCipherTransformation() { + return cipher; + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; + } + + private CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); + } + + private CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); + } + + /** + * Add handlers to channel. 
+ * + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); + } + + private static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteChannel; + private final CryptoOutputStream cos; + + EncryptionHandler(TransportCipher cipher) throws IOException { + byteChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteChannel); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(new EncryptedMessage(cos, msg, byteChannel), promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + cos.close(); + } finally { + super.close(ctx, promise); + } + } + } + + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + + DecryptionHandler(TransportCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + byteChannel.feedData((ByteBuf) data); + + byte[] decryptedData = new byte[byteChannel.readableBytes()]; + int offset = 0; + while (offset < decryptedData.length) { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + cis.close(); + } finally { + super.channelInactive(ctx); + } + } + } + + private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private long transferred; + private CryptoOutputStream cos; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. + private ByteArrayWritableChannel byteEncChannel; + private ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage(CryptoOutputStream cos, Object msg, ByteArrayWritableChannel ch) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + this.cos = cos; + this.byteEncChannel = ch; + } + + @Override + public long count() { + return isByteBuf ? 
buf.readableBytes() : region.count(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return transferred; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transfered(), "Invalid position."); + + do { + if (currentEncrypted == null) { + encryptMore(); + } + + int bytesWritten = currentEncrypted.remaining(); + target.write(currentEncrypted); + bytesWritten -= currentEncrypted.remaining(); + transferred += bytesWritten; + if (!currentEncrypted.hasRemaining()) { + currentEncrypted = null; + byteEncChannel.reset(); + } + } while (transferred < count()); + + return transferred; + } + + private void encryptMore() throws IOException { + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transfered()); + } + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index f0956438ade24..39a7495828a8a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -35,6 +35,10 @@ public final class MessageDecoder extends MessageToMessageDecoder { private static final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + public static final MessageDecoder INSTANCE = new MessageDecoder(); + + private MessageDecoder() {} + @Override public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { Message.Type msgType = Message.Type.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 276f16637efc9..997f74e1a21b4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -35,6 +35,10 @@ public final class MessageEncoder extends MessageToMessageEncoder { private static final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + public static final MessageEncoder INSTANCE = new MessageEncoder(); + + private MessageEncoder() {} + /*** * Encodes a Message by invoking its encode() method. For non-data messages, we will add one * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. 
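
The EncryptionHandler/DecryptionHandler pair installed by TransportCipher above is essentially a Netty adapter around commons-crypto's stream classes. The following self-contained sketch, which is not part of the patch (the class name and the all-zero key and IV are placeholders for illustration only), shows the underlying encrypt/decrypt round trip with the default "AES/CTR/NoPadding" transformation over plain byte-array streams:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;

// Round trip over byte-array streams; TransportCipher performs the same operation
// over Netty channels, with a separate IV for each direction of the connection.
class CtrRoundTripSketch {
  public static void main(String[] args) throws Exception {
    String transformation = "AES/CTR/NoPadding";
    Properties props = new Properties();                        // commons-crypto config
    SecretKeySpec key = new SecretKeySpec(new byte[16], "AES"); // demo key, all zeroes
    IvParameterSpec iv = new IvParameterSpec(new byte[16]);     // demo IV, all zeroes

    byte[] plaintext = "some rpc payload".getBytes(StandardCharsets.UTF_8);

    // Encrypt: everything written to the CryptoOutputStream comes out encrypted.
    ByteArrayOutputStream encrypted = new ByteArrayOutputStream();
    try (CryptoOutputStream out =
        new CryptoOutputStream(transformation, props, encrypted, key, iv)) {
      out.write(plaintext);
      out.flush();
    }

    // Decrypt: reading from the CryptoInputStream yields the original bytes.
    byte[] decrypted = new byte[plaintext.length];
    ByteArrayInputStream in = new ByteArrayInputStream(encrypted.toByteArray());
    try (CryptoInputStream cis =
        new CryptoInputStream(transformation, props, in, key, iv)) {
      int off = 0;
      while (off < decrypted.length) {
        off += cis.read(decrypted, off, decrypted.length - off);
      }
    }
    System.out.println(new String(decrypted, StandardCharsets.UTF_8));
  }
}
```

In the real handlers the two directions use the INIV and OUTIV values negotiated during authentication, and FileRegion payloads are staged through the two ByteArrayWritableChannel instances to work around CRYPTO-125.
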
diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 9e5c616ee5a1f..647813772294e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -40,24 +40,14 @@ public class SaslClientBootstrap implements TransportClientBootstrap { private static final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); - private final boolean encrypt; private final TransportConf conf; private final String appId; private final SecretKeyHolder secretKeyHolder; public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { - this(conf, appId, secretKeyHolder, false); - } - - public SaslClientBootstrap( - TransportConf conf, - String appId, - SecretKeyHolder secretKeyHolder, - boolean encrypt) { this.conf = conf; this.appId = appId; this.secretKeyHolder = secretKeyHolder; - this.encrypt = encrypt; } /** @@ -67,7 +57,7 @@ public SaslClientBootstrap( */ @Override public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt); + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); try { byte[] payload = saslClient.firstToken(); @@ -77,20 +67,21 @@ public void doBootstrap(TransportClient client, Channel channel) { msg.encode(buf); buf.writeBytes(msg.body().nioByteBuffer()); - ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs()); + ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.authRTTimeoutMs()); payload = saslClient.response(JavaUtils.bufferToArray(response)); } client.setClientId(appId); - if (encrypt) { + if (conf.saslEncryption()) { if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslClient.getNegotiatedProperty(Sasl.QOP))) { throw new RuntimeException( new SaslException("Encryption requests by negotiated non-encrypted connection.")); } + SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize()); saslClient = null; - logger.debug("Channel {} configured for SASL encryption.", client); + logger.debug("Channel {} configured for encryption.", client); } } catch (IOException ioe) { throw new RuntimeException(ioe); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index c41f5b6873f6c..0231428318add 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -42,7 +42,7 @@ * Note that the authentication process consists of multiple challenge-response pairs, each of * which are individual RPCs. */ -class SaslRpcHandler extends RpcHandler { +public class SaslRpcHandler extends RpcHandler { private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); /** Transport configuration. 
*/ @@ -59,8 +59,9 @@ class SaslRpcHandler extends RpcHandler { private SparkSaslServer saslServer; private boolean isComplete; + private boolean isAuthenticated; - SaslRpcHandler( + public SaslRpcHandler( TransportConf conf, Channel channel, RpcHandler delegate, @@ -71,6 +72,7 @@ class SaslRpcHandler extends RpcHandler { this.secretKeyHolder = secretKeyHolder; this.saslServer = null; this.isComplete = false; + this.isAuthenticated = false; } @Override @@ -80,30 +82,31 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb delegate.receive(client, message, callback); return; } + if (saslServer == null || !saslServer.isComplete()) { + ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); + SaslMessage saslMessage; + try { + saslMessage = SaslMessage.decode(nettyBuf); + } finally { + nettyBuf.release(); + } - ByteBuf nettyBuf = Unpooled.wrappedBuffer(message); - SaslMessage saslMessage; - try { - saslMessage = SaslMessage.decode(nettyBuf); - } finally { - nettyBuf.release(); - } - - if (saslServer == null) { - // First message in the handshake, setup the necessary state. - client.setClientId(saslMessage.appId); - saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); - } + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + client.setClientId(saslMessage.appId); + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, + conf.saslServerAlwaysEncrypt()); + } - byte[] response; - try { - response = saslServer.response(JavaUtils.bufferToArray( - saslMessage.body().nioByteBuffer())); - } catch (IOException ioe) { - throw new RuntimeException(ioe); + byte[] response; + try { + response = saslServer.response(JavaUtils.bufferToArray( + saslMessage.body().nioByteBuffer())); + } catch (IOException ioe) { + throw new RuntimeException(ioe); + } + callback.onSuccess(ByteBuffer.wrap(response)); } - callback.onSuccess(ByteBuffer.wrap(response)); // Setup encryption after the SASL response is sent, otherwise the client can't parse the // response. It's ok to change the channel pipeline here since we are processing an incoming @@ -111,16 +114,16 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. 
if (saslServer.isComplete()) { - logger.debug("SASL authentication successful for channel {}", client); - isComplete = true; - if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { - logger.debug("Enabling encryption for channel {}", client); - SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); - saslServer = null; - } else { - saslServer.dispose(); - saslServer = null; + if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { + logger.debug("SASL authentication successful for channel {}", client); + complete(true); + return; } + + logger.debug("Enabling encryption for channel {}", client); + SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize()); + complete(false); + return; } } @@ -155,4 +158,17 @@ public void exceptionCaught(Throwable cause, TransportClient client) { delegate.exceptionCaught(cause, client); } + private void complete(boolean dispose) { + if (dispose) { + try { + saslServer.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL server", e); + } + } + + saslServer = null; + isComplete = true; + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index c33848c8406c1..56782a8327876 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -18,7 +18,7 @@ package org.apache.spark.network.server; import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.handler.timeout.IdleState; import io.netty.handler.timeout.IdleStateEvent; import org.slf4j.Logger; @@ -26,7 +26,6 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportResponseHandler; -import org.apache.spark.network.protocol.Message; import org.apache.spark.network.protocol.RequestMessage; import org.apache.spark.network.protocol.ResponseMessage; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -48,7 +47,7 @@ * on the channel for at least `requestTimeoutMs`. Note that this is duplex traffic; we will not * timeout if the client is continuously sending but getting no responses, for simplicity. 
*/ -public class TransportChannelHandler extends SimpleChannelInboundHandler { +public class TransportChannelHandler extends ChannelInboundHandlerAdapter { private static final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); private final TransportClient client; @@ -88,14 +87,14 @@ public void channelActive(ChannelHandlerContext ctx) throws Exception { try { requestHandler.channelActive(); } catch (RuntimeException e) { - logger.error("Exception from request handler while registering channel", e); + logger.error("Exception from request handler while channel is active", e); } try { responseHandler.channelActive(); } catch (RuntimeException e) { - logger.error("Exception from response handler while registering channel", e); + logger.error("Exception from response handler while channel is active", e); } - super.channelRegistered(ctx); + super.channelActive(ctx); } @Override @@ -103,22 +102,24 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { try { requestHandler.channelInactive(); } catch (RuntimeException e) { - logger.error("Exception from request handler while unregistering channel", e); + logger.error("Exception from request handler while channel is inactive", e); } try { responseHandler.channelInactive(); } catch (RuntimeException e) { - logger.error("Exception from response handler while unregistering channel", e); + logger.error("Exception from response handler while channel is inactive", e); } - super.channelUnregistered(ctx); + super.channelInactive(ctx); } @Override - public void channelRead0(ChannelHandlerContext ctx, Message request) throws Exception { + public void channelRead(ChannelHandlerContext ctx, Object request) throws Exception { if (request instanceof RequestMessage) { requestHandler.handle((RequestMessage) request); - } else { + } else if (request instanceof ResponseMessage) { responseHandler.handle((ResponseMessage) request); + } else { + ctx.fireChannelRead(request); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 900e8eb255407..8193bc1376102 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -22,8 +22,6 @@ import com.google.common.base.Throwables; import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -189,21 +187,16 @@ private void processOneWayMessage(OneWayMessage req) { * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. 
*/ - private void respond(final Encodable result) { - final SocketAddress remoteAddress = channel.remoteAddress(); - channel.writeAndFlush(result).addListener( - new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) throws Exception { - if (future.isSuccess()) { - logger.trace("Sent result {} to client {}", result, remoteAddress); - } else { - logger.error(String.format("Error sending result %s to %s; closing connection", - result, remoteAddress), future.cause()); - channel.close(); - } - } + private void respond(Encodable result) { + SocketAddress remoteAddress = channel.remoteAddress(); + channel.writeAndFlush(result).addListener(future -> { + if (future.isSuccess()) { + logger.trace("Sent result {} to client {}", result, remoteAddress); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); } - ); + }); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java index 0d7a677820d35..047c5f3f1f094 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -89,7 +89,7 @@ private void init(String hostToBind, int portToBind) { IOMode ioMode = IOMode.valueOf(conf.ioMode()); EventLoopGroup bossGroup = - NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), conf.getModuleName() + "-server"); EventLoopGroup workerGroup = bossGroup; PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator( diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java new file mode 100644 index 0000000000000..25d103d0e316f --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteArrayReadableChannel.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +import io.netty.buffer.ByteBuf; + +public class ByteArrayReadableChannel implements ReadableByteChannel { + private ByteBuf data; + + public int readableBytes() { + return data.readableBytes(); + } + + public void feedData(ByteBuf buf) { + data = buf; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + int totalRead = 0; + while (data.readableBytes() > 0 && dst.remaining() > 0) { + int bytesToRead = Math.min(data.readableBytes(), dst.remaining()); + dst.put(data.readSlice(bytesToRead).nioBuffer()); + totalRead += bytesToRead; + } + + if (data.readableBytes() == 0) { + data.release(); + } + + return totalRead; + } + + @Override + public void close() throws IOException { + } + + @Override + public boolean isOpen() { + return true; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java index d944d9da1c7f8..f6aef499b2bfe 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -17,6 +17,7 @@ package org.apache.spark.network.util; +import java.util.Map; import java.util.NoSuchElementException; /** @@ -26,6 +27,9 @@ public abstract class ConfigProvider { /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ public abstract String get(String name); + /** Returns all the config values in the provider. */ + public abstract Iterable> getAll(); + public String get(String name, String defaultValue) { try { return get(name); @@ -49,4 +53,5 @@ public double getDouble(String name, double defaultValue) { public boolean getBoolean(String name, boolean defaultValue) { return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java new file mode 100644 index 0000000000000..a6d8358ee9004 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/CryptoUtils.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.Map; +import java.util.Properties; + +/** + * Utility methods related to the commons-crypto library. + */ +public class CryptoUtils { + + // The prefix for the configurations passing to Apache Commons Crypto library. 
+ public static final String COMMONS_CRYPTO_CONFIG_PREFIX = "commons.crypto."; + + /** + * Extract the commons-crypto configuration embedded in a list of config values. + * + * @param prefix Prefix in the given configuration that identifies the commons-crypto configs. + * @param conf List of configuration values. + */ + public static Properties toCryptoConf(String prefix, Iterable> conf) { + Properties props = new Properties(); + for (Map.Entry e : conf) { + String key = e.getKey(); + if (key.startsWith(prefix)) { + props.setProperty(COMMONS_CRYPTO_CONFIG_PREFIX + key.substring(prefix.length()), + e.getValue()); + } + } + return props; + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java index f3eaf22c0166e..afc59efaef810 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -18,10 +18,13 @@ package org.apache.spark.network.util; import java.io.Closeable; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; import java.nio.charset.StandardCharsets; +import java.util.Locale; import java.util.concurrent.TimeUnit; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -208,7 +211,7 @@ private static boolean isSymlink(File file) throws IOException { * The unit is also considered the default if the given string does not specify a unit. */ public static long timeStringAs(String str, TimeUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("(-?[0-9]+)([a-z]+)?").matcher(lower); @@ -256,7 +259,7 @@ public static long timeStringAsSec(String str) { * provided, a direct conversion to the provided unit is attempted. */ public static long byteStringAs(String str, ByteUnit unit) { - String lower = str.toLowerCase().trim(); + String lower = str.toLowerCase(Locale.ROOT).trim(); try { Matcher m = Pattern.compile("([0-9]+)([a-z]+)?").matcher(lower); @@ -344,4 +347,17 @@ public static byte[] bufferToArray(ByteBuffer buffer) { } } + /** + * Fills a buffer with data read from the channel. + */ + public static void readFully(ReadableByteChannel channel, ByteBuffer dst) throws IOException { + int expected = dst.remaining(); + while (dst.hasRemaining()) { + if (channel.read(dst) < 0) { + throw new EOFException(String.format("Not enough bytes in channel (expected %d).", + expected)); + } + } + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java index 668d2356b955d..a2cf87d1af7ed 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/MapConfigProvider.java @@ -17,17 +17,20 @@ package org.apache.spark.network.util; -import com.google.common.collect.Maps; - +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.NoSuchElementException; /** ConfigProvider based on a Map (copied in the constructor). 
*/ public class MapConfigProvider extends ConfigProvider { + + public static final MapConfigProvider EMPTY = new MapConfigProvider(Collections.emptyMap()); + private final Map config; public MapConfigProvider(Map config) { - this.config = Maps.newHashMap(config); + this.config = new HashMap<>(config); } @Override @@ -38,4 +41,16 @@ public String get(String name) { } return value; } + + @Override + public String get(String name, String defaultValue) { + String value = config.get(name); + return value == null ? defaultValue : value; + } + + @Override + public Iterable> getAll() { + return config.entrySet(); + } + } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 64eaba103cccb..a25078e262efb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -17,6 +17,9 @@ package org.apache.spark.network.util; +import java.util.Locale; +import java.util.Properties; + import com.google.common.primitives.Ints; /** @@ -24,11 +27,6 @@ */ public class TransportConf { - static { - // Set this due to Netty PR #5661 for Netty 4.0.37+ to work - System.setProperty("io.netty.maxDirectMemory", "0"); - } - private final String SPARK_NETWORK_IO_MODE_KEY; private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY; private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY; @@ -73,8 +71,14 @@ private String getConfKey(String suffix) { return "spark." + module + "." + suffix; } + public String getModuleName() { + return module; + } + /** IO mode: nio or epoll */ - public String ioMode() { return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(); } + public String ioMode() { + return conf.get(SPARK_NETWORK_IO_MODE_KEY, "NIO").toUpperCase(Locale.ROOT); + } /** If true, we will prefer allocating off-heap byte buffers within Netty. */ public boolean preferDirectBufs() { @@ -116,9 +120,10 @@ public int numConnectionsPerPeer() { /** Send buffer size (SO_SNDBUF). */ public int sendBuf() { return conf.getInt(SPARK_NETWORK_IO_SENDBUFFER_KEY, -1); } - /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ - public int saslRTTimeoutMs() { - return (int) JavaUtils.timeStringAsSec(conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s")) * 1000; + /** Timeout for a single round trip of auth message exchange, in milliseconds. */ + public int authRTTimeoutMs() { + return (int) JavaUtils.timeStringAsSec(conf.get("spark.network.auth.rpcTimeout", + conf.get(SPARK_NETWORK_SASL_TIMEOUT_KEY, "30s"))) * 1000; } /** @@ -161,7 +166,77 @@ public int portMaxRetries() { } /** - * Maximum number of bytes to be encrypted at a time when SASL encryption is enabled. + * Enables strong encryption. Also enables the new auth protocol, used to negotiate keys. + */ + public boolean encryptionEnabled() { + return conf.getBoolean("spark.network.crypto.enabled", false); + } + + /** + * The cipher transformation to use for encrypting session data. + */ + public String cipherTransformation() { + return conf.get("spark.network.crypto.cipher", "AES/CTR/NoPadding"); + } + + /** + * The key generation algorithm. This should be an algorithm that accepts a "PBEKeySpec" + * as input. The default value (PBKDF2WithHmacSHA1) is available in Java 7. 
+ */ + public String keyFactoryAlgorithm() { + return conf.get("spark.network.crypto.keyFactoryAlgorithm", "PBKDF2WithHmacSHA1"); + } + + /** + * How many iterations to run when generating keys. + * + * See some discussion about this at: http://security.stackexchange.com/q/3959 + * The default value was picked for speed, since it assumes that the secret has good entropy + * (128 bits by default), which is not generally the case with user passwords. + */ + public int keyFactoryIterations() { + return conf.getInt("spark.network.crypto.keyFactoryIterations", 1024); + } + + /** + * Encryption key length, in bits. + */ + public int encryptionKeyLength() { + return conf.getInt("spark.network.crypto.keyLength", 128); + } + + /** + * Initial vector length, in bytes. + */ + public int ivLength() { + return conf.getInt("spark.network.crypto.ivLength", 16); + } + + /** + * The algorithm for generated secret keys. Nobody should really need to change this, + * but configurable just in case. + */ + public String keyAlgorithm() { + return conf.get("spark.network.crypto.keyAlgorithm", "AES"); + } + + /** + * Whether to fall back to SASL if the new auth protocol fails. Enabled by default for + * backwards compatibility. + */ + public boolean saslFallback() { + return conf.getBoolean("spark.network.crypto.saslFallback", true); + } + + /** + * Whether to enable SASL-based encryption when authenticating using SASL. + */ + public boolean saslEncryption() { + return conf.getBoolean("spark.authenticate.enableSaslEncryption", false); + } + + /** + * Maximum number of bytes to be encrypted at a time when SASL encryption is used. + */ public int maxSaslEncryptedBlockSize() { return Ints.checkedCast(JavaUtils.byteStringAsBytes( @@ -175,4 +250,11 @@ public boolean saslServerAlwaysEncrypt() { return conf.getBoolean("spark.network.sasl.serverAlwaysEncrypt", false); } + /** + * The commons-crypto configuration for the module. 
+ */ + public Properties cryptoConf() { + return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll()); + } + } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 6d62eaf35d8cc..824482af08dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -20,6 +20,7 @@ import java.io.File; import java.io.RandomAccessFile; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; @@ -29,7 +30,6 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.google.common.io.Closeables; import org.junit.AfterClass; @@ -48,7 +48,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ChunkFetchIntegrationSuite { @@ -87,7 +87,7 @@ public static void setUp() throws Exception { Closeables.close(fp, shouldSuppressIOException); } - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); streamManager = new StreamManager() { @@ -179,49 +179,49 @@ public void onFailure(int chunkIndex, Throwable e) { @Test public void fetchBufferChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers); res.releaseBuffers(); } @Test public void fetchFileChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(FILE_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + assertBufferListsEqual(Arrays.asList(fileChunk), res.buffers); res.releaseBuffers(); } @Test public void fetchNonExistentChunk() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(12345)); + FetchResult res = fetchChunks(Arrays.asList(12345)); assertTrue(res.successChunks.isEmpty()); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertEquals(Sets.newHashSet(12345), res.failedChunks); assertTrue(res.buffers.isEmpty()); } @Test public void fetchBothChunks() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); - assertEquals(res.successChunks, 
Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX), res.successChunks); assertTrue(res.failedChunks.isEmpty()); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + assertBufferListsEqual(Arrays.asList(bufferChunk, fileChunk), res.buffers); res.releaseBuffers(); } @Test public void fetchChunkAndNonExistent() throws Exception { - FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); - assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); - assertEquals(res.failedChunks, Sets.newHashSet(12345)); - assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + FetchResult res = fetchChunks(Arrays.asList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(Sets.newHashSet(BUFFER_CHUNK_INDEX), res.successChunks); + assertEquals(Sets.newHashSet(12345), res.failedChunks); + assertBufferListsEqual(Arrays.asList(bufferChunk), res.buffers); res.releaseBuffers(); } - private void assertBufferListsEqual(List list0, List list1) + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); for (int i = 0; i < list0.size(); i ++) { @@ -229,7 +229,8 @@ private void assertBufferListsEqual(List list0, List configMap = Maps.newHashMap(); + Map configMap = new HashMap<>(); configMap.put("spark.shuffle.io.connectionTimeout", "10s"); conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); @@ -226,6 +225,8 @@ public StreamManager getStreamManager() { callback0.latch.await(60, TimeUnit.SECONDS); assertTrue(callback0.failure instanceof IOException); + // make sure callback1 is called. 
+ callback1.latch.await(60, TimeUnit.SECONDS); // failed at same time as previous assertTrue(callback1.failure instanceof IOException); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index a7a99f3bfc707..8ff737b129641 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -42,7 +42,7 @@ import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class RpcIntegrationSuite { @@ -53,7 +53,7 @@ public class RpcIntegrationSuite { @BeforeClass public static void setUp() throws Exception { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); rpcHandler = new RpcHandler() { @Override public void receive( diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index 9c49556927f0b..f253a07e64be1 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -47,7 +47,7 @@ import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportServer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class StreamSuite { @@ -91,7 +91,7 @@ public static void setUp() throws Exception { fp.close(); } - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 44d16d54225e7..e95d25fe6ae91 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -19,19 +19,20 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; -import com.google.common.collect.Maps; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertTrue; import org.apache.spark.network.client.TransportClient; @@ -40,9 +41,8 @@ import org.apache.spark.network.server.RpcHandler; import 
org.apache.spark.network.server.TransportServer; import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.SystemPropertyConfigProvider; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; public class TransportClientFactorySuite { @@ -53,7 +53,7 @@ public class TransportClientFactorySuite { @Before public void setUp() { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); RpcHandler rpcHandler = new NoOpRpcHandler(); context = new TransportContext(conf, rpcHandler); server1 = context.createServer(); @@ -72,37 +72,36 @@ public void tearDown() { * * If concurrent is true, create multiple threads to create clients in parallel. */ - private void testClientReuse(final int maxConnections, boolean concurrent) + private void testClientReuse(int maxConnections, boolean concurrent) throws IOException, InterruptedException { - Map configMap = Maps.newHashMap(); + Map configMap = new HashMap<>(); configMap.put("spark.shuffle.io.numConnectionsPerPeer", Integer.toString(maxConnections)); TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); RpcHandler rpcHandler = new NoOpRpcHandler(); TransportContext context = new TransportContext(conf, rpcHandler); - final TransportClientFactory factory = context.createClientFactory(); - final Set clients = Collections.synchronizedSet( + TransportClientFactory factory = context.createClientFactory(); + Set clients = Collections.synchronizedSet( new HashSet()); - final AtomicInteger failed = new AtomicInteger(); + AtomicInteger failed = new AtomicInteger(); Thread[] attempts = new Thread[maxConnections * 10]; // Launch a bunch of threads to create new clients. for (int i = 0; i < attempts.length; i++) { - attempts[i] = new Thread() { - @Override - public void run() { - try { - TransportClient client = - factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertTrue(client.isActive()); - clients.add(client); - } catch (IOException e) { - failed.incrementAndGet(); - } + attempts[i] = new Thread(() -> { + try { + TransportClient client = + factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(client.isActive()); + clients.add(client); + } catch (IOException e) { + failed.incrementAndGet(); + } catch (InterruptedException e) { + throw new RuntimeException(e); } - }; + }); if (concurrent) { attempts[i].start(); @@ -112,8 +111,8 @@ public void run() { } // Wait until all the threads complete. 
- for (int i = 0; i < attempts.length; i++) { - attempts[i].join(); + for (Thread attempt : attempts) { + attempt.join(); } Assert.assertEquals(0, failed.get()); @@ -143,13 +142,13 @@ public void reuseClientsUpToConfigVariableConcurrent() throws Exception { } @Test - public void returnDifferentClientsForDifferentServers() throws IOException { + public void returnDifferentClientsForDifferentServers() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); assertTrue(c1.isActive()); assertTrue(c2.isActive()); - assertTrue(c1 != c2); + assertNotSame(c1, c2); factory.close(); } @@ -166,13 +165,13 @@ public void neverReturnInactiveClients() throws IOException, InterruptedExceptio assertFalse(c1.isActive()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); - assertFalse(c1 == c2); + assertNotSame(c1, c2); assertTrue(c2.isActive()); factory.close(); } @Test - public void closeBlockClientsWithFactory() throws IOException { + public void closeBlockClientsWithFactory() throws IOException, InterruptedException { TransportClientFactory factory = context.createClientFactory(); TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); @@ -199,10 +198,14 @@ public String get(String name) { } return value; } + + @Override + public Iterable> getAll() { + throw new UnsupportedOperationException(); + } }); TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); - TransportClientFactory factory = context.createClientFactory(); - try { + try (TransportClientFactory factory = context.createClientFactory()) { TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); assertTrue(c1.isActive()); long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds @@ -210,8 +213,6 @@ public String get(String name) { Thread.sleep(10); } assertFalse(c1.isActive()); - } finally { - factory.close(); } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 128f7cba74350..09fc80d12d510 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -24,11 +24,8 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; import static org.mockito.Mockito.*; -import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; @@ -54,7 +51,7 @@ public void handleSuccessfulFetch() throws Exception { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onSuccess(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -67,7 +64,7 @@ public void 
handleFailedFetch() throws Exception { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); - verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(0), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -84,9 +81,9 @@ public void clearAllOutstandingRequests() throws Exception { handler.exceptionCaught(new Exception("duh duh duhhhh")); // should fail both b2 and b3 - verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); - verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); - verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + verify(callback, times(1)).onSuccess(eq(0), any()); + verify(callback, times(1)).onFailure(eq(1), any()); + verify(callback, times(1)).onFailure(eq(2), any()); assertEquals(0, handler.numOutstandingRequests()); } @@ -118,7 +115,7 @@ public void handleFailedRPC() throws Exception { assertEquals(1, handler.numOutstandingRequests()); handler.handle(new RpcFailure(12345, "oh no")); - verify(callback, times(1)).onFailure((Throwable) any()); + verify(callback, times(1)).onFailure(any()); assertEquals(0, handler.numOutstandingRequests()); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java new file mode 100644 index 0000000000000..a3519fe4a423e --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.util.Arrays; +import static java.nio.charset.StandardCharsets.UTF_8; + +import org.junit.BeforeClass; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class AuthEngineSuite { + + private static TransportConf conf; + + @BeforeClass + public static void setUp() { + conf = new TransportConf("rpc", MapConfigProvider.EMPTY); + } + + @Test + public void testAuthEngine() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf); + + try { + ClientChallenge clientChallenge = client.challenge(); + ServerResponse serverResponse = server.respond(clientChallenge); + client.validate(serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + + assertTrue(Arrays.equals(serverCipher.getInputIv(), clientCipher.getOutputIv())); + assertTrue(Arrays.equals(serverCipher.getOutputIv(), clientCipher.getInputIv())); + assertEquals(serverCipher.getKey(), clientCipher.getKey()); + } finally { + client.close(); + server.close(); + } + } + + @Test + public void testMismatchedSecret() throws Exception { + AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "different_secret", conf); + + ClientChallenge clientChallenge = client.challenge(); + try { + server.respond(clientChallenge); + fail("Should have failed to validate response."); + } catch (IllegalArgumentException e) { + // Expected. + } + } + + @Test(expected = IllegalArgumentException.class) + public void testWrongAppId() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = engine.challenge(new byte[] { 0x00 }, challenge.nonce, + engine.rawResponse(engine.challenge)); + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + + @Test(expected = IllegalArgumentException.class) + public void testWrongNonce() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = engine.challenge(challenge.appId.getBytes(UTF_8), new byte[] { 0x00 }, + engine.rawResponse(engine.challenge)); + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + + @Test(expected = IllegalArgumentException.class) + public void testBadChallenge() throws Exception { + AuthEngine engine = new AuthEngine("appId", "secret", conf); + ClientChallenge challenge = engine.challenge(); + + byte[] badChallenge = new byte[challenge.challenge.length]; + engine.respond(new ClientChallenge(challenge.appId, challenge.kdf, challenge.iterations, + challenge.cipher, challenge.keyLength, challenge.nonce, badChallenge)); + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java new file mode 100644 index 0000000000000..8751944a1c2a3 --- /dev/null +++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import io.netty.channel.Channel; +import org.junit.After; +import org.junit.Test; +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.TransportServerBootstrap; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class AuthIntegrationSuite { + + private AuthTestCtx ctx; + + @After + public void cleanUp() throws Exception { + if (ctx != null) { + ctx.close(); + } + ctx = null; + } + + @Test + public void testNewAuth() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret"); + ctx.createClient("secret"); + + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertTrue(ctx.authRpcHandler.doDelegate); + assertFalse(ctx.authRpcHandler.delegate instanceof SaslRpcHandler); + } + + @Test + public void testAuthFailure() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("server"); + + try { + ctx.createClient("client"); + fail("Should have failed to create client."); + } catch (Exception e) { + assertFalse(ctx.authRpcHandler.doDelegate); + assertFalse(ctx.serverChannel.isActive()); + } + } + + @Test + public void testSaslServerFallback() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret", true); + ctx.createClient("secret", false); + + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + } + + @Test + public void testSaslClientFallback() throws Exception { + ctx = new AuthTestCtx(); + ctx.createServer("secret", false); + ctx.createClient("secret", true); + + ByteBuffer reply = 
ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + } + + @Test + public void testAuthReplay() throws Exception { + // This test covers the case where an attacker replays a challenge message sniffed from the + // network, but doesn't know the actual secret. The server should close the connection as + // soon as a message is sent after authentication is performed. This is emulated by removing + // the client encryption handler after authentication. + ctx = new AuthTestCtx(); + ctx.createServer("secret"); + ctx.createClient("secret"); + + assertNotNull(ctx.client.getChannel().pipeline() + .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); + + try { + ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + fail("Should have failed unencrypted RPC."); + } catch (Exception e) { + assertTrue(ctx.authRpcHandler.doDelegate); + } + } + + private class AuthTestCtx { + + private final String appId = "testAppId"; + private final TransportConf conf; + private final TransportContext ctx; + + TransportClient client; + TransportServer server; + volatile Channel serverChannel; + volatile AuthRpcHandler authRpcHandler; + + AuthTestCtx() throws Exception { + Map testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); + + RpcHandler rpcHandler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + assertEquals("Ping", JavaUtils.bytesToString(message)); + callback.onSuccess(JavaUtils.stringToBytes("Pong")); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + }; + + this.ctx = new TransportContext(conf, rpcHandler); + } + + void createServer(String secret) throws Exception { + createServer(secret, true); + } + + void createServer(String secret, boolean enableAes) throws Exception { + TransportServerBootstrap introspector = (channel, rpcHandler) -> { + this.serverChannel = channel; + if (rpcHandler instanceof AuthRpcHandler) { + this.authRpcHandler = (AuthRpcHandler) rpcHandler; + } + return rpcHandler; + }; + SecretKeyHolder keyHolder = createKeyHolder(secret); + TransportServerBootstrap auth = enableAes ? new AuthServerBootstrap(conf, keyHolder) + : new SaslServerBootstrap(conf, keyHolder); + this.server = ctx.createServer(Arrays.asList(auth, introspector)); + } + + void createClient(String secret) throws Exception { + createClient(secret, true); + } + + void createClient(String secret, boolean enableAes) throws Exception { + TransportConf clientConf = enableAes ? 
conf + : new TransportConf("rpc", MapConfigProvider.EMPTY); + List bootstraps = Arrays.asList( + new AuthClientBootstrap(clientConf, appId, createKeyHolder(secret))); + this.client = ctx.createClientFactory(bootstraps) + .createClient(TestUtils.getLocalHost(), server.getPort()); + } + + void close() { + if (client != null) { + client.close(); + } + if (server != null) { + server.close(); + } + } + + private SecretKeyHolder createKeyHolder(String secret) { + SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); + when(keyHolder.getSaslUser(anyString())).thenReturn(appId); + when(keyHolder.getSecretKey(anyString())).thenReturn(secret); + return keyHolder; + } + + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java new file mode 100644 index 0000000000000..a90ff247da4fc --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthMessagesSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.crypto; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import static org.junit.Assert.*; + +import org.apache.spark.network.protocol.Encodable; + +public class AuthMessagesSuite { + + private static int COUNTER = 0; + + private static String string() { + return String.valueOf(COUNTER++); + } + + private static byte[] byteArray() { + byte[] bytes = new byte[COUNTER++]; + for (int i = 0; i < bytes.length; i++) { + bytes[i] = (byte) COUNTER; + } return bytes; + } + + private static int integer() { + return COUNTER++; + } + + @Test + public void testClientChallenge() { + ClientChallenge msg = new ClientChallenge(string(), string(), integer(), string(), integer(), + byteArray(), byteArray()); + ClientChallenge decoded = ClientChallenge.decodeMessage(encode(msg)); + + assertEquals(msg.appId, decoded.appId); + assertEquals(msg.kdf, decoded.kdf); + assertEquals(msg.iterations, decoded.iterations); + assertEquals(msg.cipher, decoded.cipher); + assertEquals(msg.keyLength, decoded.keyLength); + assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); + assertTrue(Arrays.equals(msg.challenge, decoded.challenge)); + } + + @Test + public void testServerResponse() { + ServerResponse msg = new ServerResponse(byteArray(), byteArray(), byteArray(), byteArray()); + ServerResponse decoded = ServerResponse.decodeMessage(encode(msg)); + assertTrue(Arrays.equals(msg.response, decoded.response)); + assertTrue(Arrays.equals(msg.nonce, decoded.nonce)); + assertTrue(Arrays.equals(msg.inputIv, decoded.inputIv)); + assertTrue(Arrays.equals(msg.outputIv, decoded.outputIv)); + } + + private ByteBuffer encode(Encodable msg) { + ByteBuf buf = Unpooled.buffer(); + msg.encode(buf); + return buf.nioBuffer(); + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 45cc03df435ac..6f15718bd8705 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -23,8 +23,11 @@ import java.io.File; import java.lang.reflect.Method; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeoutException; @@ -32,7 +35,7 @@ import java.util.concurrent.atomic.AtomicReference; import javax.security.sasl.SaslException; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableMap; import com.google.common.io.ByteStreams; import com.google.common.io.Files; import io.netty.buffer.ByteBuf; @@ -42,8 +45,6 @@ import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.ChannelPromise; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; @@ -59,7 +60,7 @@ import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.util.ByteArrayWritableChannel; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; 
import org.apache.spark.network.util.TransportConf; /** @@ -134,18 +135,15 @@ public void testSaslEncryption() throws Throwable { testBasicSasl(true); } - private void testBasicSasl(boolean encrypt) throws Throwable { + private static void testBasicSasl(boolean encrypt) throws Throwable { RpcHandler rpcHandler = mock(RpcHandler.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; - RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; - assertEquals("Ping", JavaUtils.bytesToString(message)); - cb.onSuccess(JavaUtils.stringToBytes("Pong")); - return null; - } - }) + doAnswer(invocation -> { + ByteBuffer message = (ByteBuffer) invocation.getArguments()[1]; + RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2]; + assertEquals("Ping", JavaUtils.bytesToString(message)); + cb.onSuccess(JavaUtils.stringToBytes("Pong")); + return null; + }) .when(rpcHandler) .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class)); @@ -224,7 +222,7 @@ public void testEncryptedMessage() throws Exception { public void testEncryptedMessageChunking() throws Exception { File file = File.createTempFile("sasltest", ".txt"); try { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); byte[] data = new byte[8 * 1024]; new Random().nextBytes(data); @@ -252,21 +250,17 @@ public void testEncryptedMessageChunking() throws Exception { @Test public void testFileRegionEncryption() throws Exception { - final String blockSizeConf = "spark.network.sasl.maxEncryptedBlockSize"; - System.setProperty(blockSizeConf, "1k"); + Map testConf = ImmutableMap.of( + "spark.network.sasl.maxEncryptedBlockSize", "1k"); - final AtomicReference response = new AtomicReference<>(); - final File file = File.createTempFile("sasltest", ".txt"); + AtomicReference response = new AtomicReference<>(); + File file = File.createTempFile("sasltest", ".txt"); SaslTestCtx ctx = null; try { - final TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(new Answer() { - @Override - public ManagedBuffer answer(InvocationOnMock invocation) { - return new FileSegmentManagedBuffer(conf, file, 0, file.length()); - } - }); + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation -> + new FileSegmentManagedBuffer(conf, file, 0, file.length())); RpcHandler rpcHandler = mock(RpcHandler.class); when(rpcHandler.getStreamManager()).thenReturn(sm); @@ -275,20 +269,17 @@ public ManagedBuffer answer(InvocationOnMock invocation) { new Random().nextBytes(data); Files.write(data, file); - ctx = new SaslTestCtx(rpcHandler, true, false); + ctx = new SaslTestCtx(rpcHandler, true, false, testConf); - final CountDownLatch lock = new CountDownLatch(1); + CountDownLatch lock = new CountDownLatch(1); ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) { - response.set((ManagedBuffer) invocation.getArguments()[1]); - response.get().retain(); - lock.countDown(); - return null; - } - }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); + doAnswer(invocation -> { 
+ response.set((ManagedBuffer) invocation.getArguments()[1]); + response.get().retain(); + lock.countDown(); + return null; + }).when(callback).onSuccess(anyInt(), any(ManagedBuffer.class)); ctx.client.fetchChunk(0, 0, callback); lock.await(10, TimeUnit.SECONDS); @@ -306,18 +297,15 @@ public Void answer(InvocationOnMock invocation) { if (response.get() != null) { response.get().release(); } - System.clearProperty(blockSizeConf); } } @Test public void testServerAlwaysEncrypt() throws Exception { - final String alwaysEncryptConfName = "spark.network.sasl.serverAlwaysEncrypt"; - System.setProperty(alwaysEncryptConfName, "true"); - SaslTestCtx ctx = null; try { - ctx = new SaslTestCtx(mock(RpcHandler.class), false, false); + ctx = new SaslTestCtx(mock(RpcHandler.class), false, false, + ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")); fail("Should have failed to connect without encryption."); } catch (Exception e) { assertTrue(e.getCause() instanceof SaslException); @@ -325,7 +313,6 @@ public void testServerAlwaysEncrypt() throws Exception { if (ctx != null) { ctx.close(); } - System.clearProperty(alwaysEncryptConfName); } } @@ -389,7 +376,21 @@ private static class SaslTestCtx { boolean disableClientEncryption) throws Exception { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + this(rpcHandler, encrypt, disableClientEncryption, Collections.emptyMap()); + } + + SaslTestCtx( + RpcHandler rpcHandler, + boolean encrypt, + boolean disableClientEncryption, + Map extraConf) + throws Exception { + + Map testConf = ImmutableMap.builder() + .putAll(extraConf) + .put("spark.authenticate.enableSaslEncryption", String.valueOf(encrypt)) + .build(); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); SecretKeyHolder keyHolder = mock(SecretKeyHolder.class); when(keyHolder.getSaslUser(anyString())).thenReturn("user"); @@ -397,13 +398,14 @@ private static class SaslTestCtx { TransportContext ctx = new TransportContext(conf, rpcHandler); - this.checker = new EncryptionCheckerBootstrap(); + this.checker = new EncryptionCheckerBootstrap(SaslEncryption.ENCRYPTION_HANDLER_NAME); + this.server = ctx.createServer(Arrays.asList(new SaslServerBootstrap(conf, keyHolder), checker)); try { - List clientBootstraps = Lists.newArrayList(); - clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder, encrypt)); + List clientBootstraps = new ArrayList<>(); + clientBootstraps.add(new SaslClientBootstrap(conf, "user", keyHolder)); if (disableClientEncryption) { clientBootstraps.add(new EncryptionDisablerBootstrap()); } @@ -437,22 +439,22 @@ private static class EncryptionCheckerBootstrap extends ChannelOutboundHandlerAd implements TransportServerBootstrap { boolean foundEncryptionHandler; + String encryptHandlerName; + + EncryptionCheckerBootstrap(String encryptHandlerName) { + this.encryptHandlerName = encryptHandlerName; + } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { if (!foundEncryptionHandler) { foundEncryptionHandler = - ctx.channel().pipeline().get(SaslEncryption.ENCRYPTION_HANDLER_NAME) != null; + ctx.channel().pipeline().get(encryptHandlerName) != null; } ctx.write(msg, promise); } - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - super.handlerRemoved(ctx); - } - @Override public RpcHandler doBootstrap(Channel channel, RpcHandler rpcHandler) { channel.pipeline().addFirst("encryptionChecker", this); diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java new file mode 100644 index 0000000000000..2b45d1e39713c --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/util/CryptoUtilsSuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.Map; +import java.util.Properties; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; +import static org.junit.Assert.*; + +public class CryptoUtilsSuite { + + @Test + public void testConfConversion() { + String prefix = "my.prefix.commons.config."; + + String confKey1 = prefix + "a.b.c"; + String confVal1 = "val1"; + String cryptoKey1 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "a.b.c"; + + String confKey2 = prefix.substring(0, prefix.length() - 1) + "A.b.c"; + String confVal2 = "val2"; + String cryptoKey2 = CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX + "A.b.c"; + + Map conf = ImmutableMap.of( + confKey1, confVal1, + confKey2, confVal2); + + Properties cryptoConf = CryptoUtils.toCryptoConf(prefix, conf.entrySet()); + + assertEquals(confVal1, cryptoConf.getProperty(cryptoKey1)); + assertFalse(cryptoConf.containsKey(cryptoKey2)); + } + +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index d4de4a941d480..b53e41303751c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -28,8 +28,6 @@ import io.netty.channel.ChannelHandlerContext; import org.junit.AfterClass; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -52,7 +50,7 @@ public void testFrameDecoding() throws Exception { @Test public void testInterception() throws Exception { - final int interceptedReads = 3; + int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); ChannelHandlerContext ctx = mockChannelHandlerContext(); @@ -84,22 +82,19 @@ public void testInterception() throws Exception { public void testRetainedFrames() throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); - final AtomicInteger count = new AtomicInteger(); - final List retained = new ArrayList<>(); + AtomicInteger count = new 
AtomicInteger(); + List retained = new ArrayList<>(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - // Retain a few frames but not others. - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - if (count.incrementAndGet() % 2 == 0) { - retained.add(buf); - } else { - buf.release(); - } - return null; + when(ctx.fireChannelRead(any())).thenAnswer(in -> { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); } + return null; }); ByteBuf data = createAndFeedFrames(100, decoder, ctx); @@ -150,12 +145,6 @@ public void testEmptyFrame() throws Exception { testInvalidFrame(8); } - @Test(expected = IllegalArgumentException.class) - public void testLargeFrame() throws Exception { - // Frame length includes the frame size field, so need to add a few more bytes. - testInvalidFrame(Integer.MAX_VALUE + 9); - } - /** * Creates a number of randomly sized frames and feed them to the given decoder, verifying * that the frames were read. @@ -210,13 +199,10 @@ private void testInvalidFrame(long size) throws Exception { private ChannelHandlerContext mockChannelHandlerContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock in) { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); - return null; - } + when(ctx.fireChannelRead(any())).thenAnswer(in -> { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; }); return ctx; } diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 511e1f29de368..2de882adcb582 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -70,6 +70,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + log4j log4j diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6e02430a8edb8..c0f1da50f5e65 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? 
block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } @@ -190,12 +199,8 @@ private ShuffleMetrics() { allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis); allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis); allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); - allMetrics.put("registeredExecutorsSize", new Gauge() { - @Override - public Integer getValue() { - return blockManager.getRegisteredExecutorsSize(); - } - }); + allMetrics.put("registeredExecutorsSize", + (Gauge) () -> blockManager.getRegisteredExecutorsSize()); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 25e9abde708d6..62d58aba4c1e7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -205,12 +205,7 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); // Execute the actual deletion in a different thread, as it may take some time. 
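The handler hunk above replaces an eagerly materialized list of block buffers with an iterator that opens each block only when the stream layer asks for it, and moves the byte-count metric into next(). A minimal standalone sketch of that pattern, using a hypothetical loader function in place of Spark's blockManager and ManagedBuffer:

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.function.Function;

// Hypothetical stand-in for the lazy block iterator registered with the stream manager.
final class LazyBlockIterator<B> implements Iterator<B> {
  private final String[] blockIds;
  private final Function<String, B> loader;  // e.g. id -> blockManager.getBlockData(appId, execId, id)
  private int index = 0;

  LazyBlockIterator(String[] blockIds, Function<String, B> loader) {
    this.blockIds = blockIds;
    this.loader = loader;
  }

  @Override
  public boolean hasNext() {
    return index < blockIds.length;
  }

  @Override
  public B next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    // Blocks are opened one at a time, so a large request no longer pins
    // every buffer in memory before the stream is consumed.
    return loader.apply(blockIds[index++]);
  }
}
```

One consequence, visible in the hunk above, is that per-block accounting (blockTransferRateBytes.mark) has to happen inside next() rather than after a single up-front loop.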
- directoryCleaner.execute(new Runnable() { - @Override - public void run() { - deleteExecutorDirs(executor.localDirs); - } - }); + directoryCleaner.execute(() -> deleteExecutorDirs(executor.localDirs)); } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 772fb88325b35..2c5827bf7dc56 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -21,7 +21,6 @@ import java.nio.ByteBuffer; import java.util.List; -import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,7 +29,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; @@ -47,8 +46,7 @@ public class ExternalShuffleClient extends ShuffleClient { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); private final TransportConf conf; - private final boolean saslEnabled; - private final boolean saslEncryptionEnabled; + private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; protected TransportClientFactory clientFactory; @@ -61,15 +59,10 @@ public class ExternalShuffleClient extends ShuffleClient { public ExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - Preconditions.checkArgument( - !saslEncryptionEnabled || saslEnabled, - "SASL encryption can only be enabled if SASL is also enabled."); + boolean authEnabled) { this.conf = conf; this.secretKeyHolder = secretKeyHolder; - this.saslEnabled = saslEnabled; - this.saslEncryptionEnabled = saslEncryptionEnabled; + this.authEnabled = authEnabled; } protected void checkInit() { @@ -81,31 +74,27 @@ public void init(String appId) { this.appId = appId; TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); - if (saslEnabled) { - bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } clientFactory = context.createClientFactory(bootstraps); } @Override public void fetchBlocks( - final String host, - final int port, - final String execId, + String host, + int port, + String execId, String[] blockIds, BlockFetchingListener listener) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = - new RetryingBlockFetcher.BlockFetchStarter() { - @Override - public void createAndStart(String[] blockIds, BlockFetchingListener listener) - throws IOException { + (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, blockIds, 
listener).start(); - } - }; + new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1).start(); + }; int maxRetries = conf.maxIORetries(); if (maxRetries > 0) { @@ -136,14 +125,11 @@ public void registerWithShuffleServer( String host, int port, String execId, - ExecutorShuffleInfo executorInfo) throws IOException { + ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { checkInit(); - TransportClient client = clientFactory.createUnmanagedClient(host, port); - try { + try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); - } finally { - client.close(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 72bd0f803da33..f309dda8afca6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -57,7 +57,8 @@ public interface BlockFetchStarter { * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection * issues. */ - void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException, InterruptedException; } /** Shared executor service used for waiting and retrying. */ @@ -163,12 +164,9 @@ private synchronized void initiateRetry() { logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); - executorService.submit(new Runnable() { - @Override - public void run() { - Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); - fetchAllOutstanding(); - } + executorService.submit(() -> { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); }); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 42cedd9943150..dbc1010847fb1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -60,16 +60,15 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient { public MesosExternalShuffleClient( TransportConf conf, SecretKeyHolder secretKeyHolder, - boolean saslEnabled, - boolean saslEncryptionEnabled) { - super(conf, secretKeyHolder, saslEnabled, saslEncryptionEnabled); + boolean authEnabled) { + super(conf, secretKeyHolder, authEnabled); } public void registerDriverWithShuffleService( String host, int port, long heartbeatTimeoutMs, - long heartbeatIntervalMs) throws IOException { + long heartbeatIntervalMs) throws IOException, InterruptedException { checkInit(); ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index 6ba937dddb2a7..c0e170e5b9353 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; -import com.google.common.collect.Lists; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -38,7 +38,6 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; @@ -55,7 +54,7 @@ import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class SaslIntegrationSuite { @@ -73,7 +72,7 @@ public class SaslIntegrationSuite { @BeforeClass public static void beforeAll() throws IOException { - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); context = new TransportContext(conf, new TestRpcHandler()); secretKeyHolder = mock(SecretKeyHolder.class); @@ -103,10 +102,9 @@ public void afterEach() { } @Test - public void testGoodClient() throws IOException { + public void testGoodClient() throws IOException, InterruptedException { clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); String msg = "Hello, World!"; @@ -120,8 +118,7 @@ public void testBadClient() { when(badKeyHolder.getSaslUser(anyString())).thenReturn("other-app"); when(badKeyHolder.getSecretKey(anyString())).thenReturn("wrong-password"); clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder))); try { // Bootstrap should fail on startup. 
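The SASL suites in this patch stub SecretKeyHolder with Mockito so that client and server either share a secret or deliberately disagree. A small hedged helper capturing that setup (getSaslUser/getSecretKey are the calls used throughout this diff; the helper name itself is illustrative):

```java
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import org.apache.spark.network.sasl.SecretKeyHolder;

final class TestSecrets {
  private TestSecrets() {}

  // A SecretKeyHolder that hands out the same user and secret for every app id.
  static SecretKeyHolder constantKeyHolder(String user, String secret) {
    SecretKeyHolder holder = mock(SecretKeyHolder.class);
    when(holder.getSaslUser(anyString())).thenReturn(user);
    when(holder.getSecretKey(anyString())).thenReturn(secret);
    return holder;
  }
}
```

testBadClient then only needs a holder built from a mismatched user/password pair to drive the bootstrap failure path.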
@@ -133,9 +130,8 @@ public void testBadClient() { } @Test - public void testNoSaslClient() throws IOException { - clientFactory = context.createClientFactory( - Lists.newArrayList()); + public void testNoSaslClient() throws IOException, InterruptedException { + clientFactory = context.createClientFactory(new ArrayList<>()); TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); try { @@ -159,15 +155,11 @@ public void testNoSaslServer() { RpcHandler handler = new TestRpcHandler(); TransportContext context = new TransportContext(conf, handler); clientFactory = context.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); - TransportServer server = context.createServer(); - try { + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + try (TransportServer server = context.createServer()) { clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); } catch (Exception e) { assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); - } finally { - server.close(); } } @@ -191,14 +183,13 @@ public void testAppIsolation() throws Exception { try { // Create a client, and make a request to fetch blocks from a different app. clientFactory = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder))); client1 = clientFactory.createClient(TestUtils.getLocalHost(), blockServer.getPort()); - final AtomicReference exception = new AtomicReference<>(); + AtomicReference exception = new AtomicReference<>(); - final CountDownLatch blockFetchLatch = new CountDownLatch(1); + CountDownLatch blockFetchLatch = new CountDownLatch(1); BlockFetchingListener listener = new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { @@ -235,12 +226,11 @@ public void onBlockFetchFailure(String blockId, Throwable t) { // Create a second client, authenticated with a different app ID, and try to read from // the stream created for the previous app. 
clientFactory2 = blockServerContext.createClientFactory( - Lists.newArrayList( - new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); + Arrays.asList(new SaslClientBootstrap(conf, "app-2", secretKeyHolder))); client2 = clientFactory2.createClient(TestUtils.getLocalHost(), blockServer.getPort()); - final CountDownLatch chunkReceivedLatch = new CountDownLatch(1); + CountDownLatch chunkReceivedLatch = new CountDownLatch(1); ChunkReceivedCallback callback = new ChunkReceivedCallback() { @Override public void onSuccess(int chunkIndex, ManagedBuffer buffer) { @@ -284,7 +274,7 @@ public StreamManager getStreamManager() { } } - private void checkSecurityException(Throwable t) { + private static void checkSecurityException(Throwable t) { assertNotNull("No exception was caught.", t); assertTrue("Expected SecurityException.", t.getMessage().contains(SecurityException.class.getName())); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index c036bc2e8d256..4d48b18970386 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,12 +88,10 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); - verify(callback, never()).onFailure((Throwable) any()); + verify(callback, never()).onFailure(any()); StreamHandle handle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 35d6346474d5d..bc97594903bef 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -25,7 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.io.CharStreams; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId; import org.junit.AfterClass; @@ -42,7 +42,7 @@ public class 
ExternalShuffleBlockResolverSuite { private static TestShuffleDataContext dataContext; private static final TransportConf conf = - new TransportConf("shuffle", new SystemPropertyConfigProvider()); + new TransportConf("shuffle", MapConfigProvider.EMPTY); @BeforeClass public static void beforeAll() throws IOException { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index bdd218db69b54..47c087088a8a2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -29,14 +29,14 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ExternalShuffleCleanupSuite { // Same-thread Executor used to ensure cleanup happens synchronously in test thread. private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - private TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; @Test @@ -60,12 +60,10 @@ public void noCleanupAndCleanup() throws IOException { public void cleanupUsesExecutor() throws IOException { TestShuffleDataContext dataContext = createSomeData(); - final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + AtomicBoolean cleanupCalled = new AtomicBoolean(false); // Executor which does nothing to ensure we're actually using it. 
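The next hunk collapses that do-nothing Executor into a lambda. A self-contained sketch of the same trick, useful whenever a test needs to observe that work was scheduled without actually running it:

```java
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;

public class RecordingExecutorExample {
  public static void main(String[] args) {
    AtomicBoolean cleanupCalled = new AtomicBoolean(false);

    // Executor that never runs the task; it only records that something was submitted.
    Executor noThreadExecutor = runnable -> cleanupCalled.set(true);

    noThreadExecutor.execute(() -> System.out.println("never printed"));
    System.out.println("cleanup scheduled: " + cleanupCalled.get());  // prints true
  }
}
```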
- Executor noThreadExecutor = new Executor() { - @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } - }; + Executor noThreadExecutor = runnable -> cleanupCalled.set(true); ExternalShuffleBlockResolver manager = new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 552b5366c5930..7a33b6821792c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; @@ -28,7 +29,7 @@ import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import org.junit.After; import org.junit.AfterClass; @@ -43,7 +44,7 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; public class ExternalShuffleIntegrationSuite { @@ -84,7 +85,7 @@ public static void beforeAll() throws IOException { dataContext0.create(); dataContext0.insertSortShuffleData(0, 0, exec0Blocks); - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); @@ -115,12 +116,16 @@ public void releaseBuffers() { // Fetch a set of blocks from a pre-registered executor. private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { - return fetchBlocks(execId, blockIds, server.getPort()); + return fetchBlocks(execId, blockIds, conf, server.getPort()); } // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, // to allow connecting to invalid servers. 
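A recurring theme of this patch is replacing SystemPropertyConfigProvider, and the System.setProperty/clearProperty dance around it, with a per-test TransportConf built from a MapConfigProvider. A minimal sketch mirroring the constructor calls that appear later in this hunk:

```java
import com.google.common.collect.ImmutableMap;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

final class TestConfs {
  private TestConfs() {}

  // An isolated shuffle TransportConf with retries disabled, scoped to one test
  // instead of leaking through JVM-wide system properties.
  static TransportConf noRetryShuffleConf() {
    return new TransportConf("shuffle",
        new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0")));
  }
}
```

Because the configuration lives in the conf object, testFetchNoServer below no longer needs a try/finally block to restore global state.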
- private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + private FetchResult fetchBlocks( + String execId, + String[] blockIds, + TransportConf clientConf, + int port) throws Exception { final FetchResult res = new FetchResult(); res.successBlocks = Collections.synchronizedSet(new HashSet()); res.failedBlocks = Collections.synchronizedSet(new HashSet()); @@ -128,7 +133,7 @@ private FetchResult fetchBlocks(String execId, String[] blockIds, int port) thro final Semaphore requestsRemaining = new Semaphore(0); - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + ExternalShuffleClient client = new ExternalShuffleClient(clientConf, null, false); client.init(APP_ID); client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, new BlockFetchingListener() { @@ -168,7 +173,7 @@ public void testFetchOneSort() throws Exception { FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks[0])); exec0Fetch.releaseBuffers(); } @@ -180,7 +185,7 @@ public void testFetchThreeSort() throws Exception { assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), exec0Fetch.successBlocks); assertTrue(exec0Fetch.failedBlocks.isEmpty()); - assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks)); exec0Fetch.releaseBuffers(); } @@ -211,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. 
- assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test @@ -227,27 +231,24 @@ public void testFetchUnregisteredExecutor() throws Exception { @Test public void testFetchNoServer() throws Exception { - System.setProperty("spark.shuffle.io.maxRetries", "0"); - try { - registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-0", - new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); - } finally { - System.clearProperty("spark.shuffle.io.maxRetries"); - } + TransportConf clientConf = new TransportConf("shuffle", + new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0"))); + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); } - private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) - throws IOException { - ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, false); + private static void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException, InterruptedException { + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); client.init(APP_ID); client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), executorId, executorInfo); } - private void assertBufferListsEqual(List list0, List list1) + private static void assertBufferListsEqual(List list0, List list1) throws Exception { assertEquals(list0.size(), list1.size()); for (int i = 0; i < list0.size(); i ++) { @@ -255,7 +256,8 @@ private void assertBufferListsEqual(List list0, List list } } - private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + private static void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) + throws Exception { ByteBuffer nio0 = buffer0.nioByteBuffer(); ByteBuffer nio1 = buffer1.nioByteBuffer(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java index a0f69ca29a280..bf20c577ed420 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.util.Arrays; +import com.google.common.collect.ImmutableMap; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -33,12 +34,12 @@ import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import 
org.apache.spark.network.util.TransportConf; public class ExternalShuffleSecuritySuite { - TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); + TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); TransportServer server; @Before @@ -59,7 +60,7 @@ public void afterEach() { } @Test - public void testValid() throws IOException { + public void testValid() throws IOException, InterruptedException { validate("my-app-id", "secret", false); } @@ -82,14 +83,21 @@ public void testBadSecret() { } @Test - public void testEncryption() throws IOException { + public void testEncryption() throws IOException, InterruptedException { validate("my-app-id", "secret", true); } /** Creates an ExternalShuffleClient and attempts to register with the server. */ - private void validate(String appId, String secretKey, boolean encrypt) throws IOException { + private void validate(String appId, String secretKey, boolean encrypt) + throws IOException, InterruptedException { + TransportConf testConf = conf; + if (encrypt) { + testConf = new TransportConf("shuffle", new MapConfigProvider( + ImmutableMap.of("spark.authenticate.enableSaslEncryption", "true"))); + } + ExternalShuffleClient client = - new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true, encrypt); + new ExternalShuffleClient(testConf, new TestSecretKeyHolder(appId, secretKey), true); client.init(appId); // Registration either succeeds or throws an exception. client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 2590b9ce4c1f1..3e51fea3cf0e5 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -25,8 +25,6 @@ import com.google.common.collect.Maps; import io.netty.buffer.Unpooled; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -85,8 +83,8 @@ public void testFailure() { // Each failure will cause a failure to be invoked in all remaining block fetches. verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); - verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any()); } @Test @@ -100,15 +98,15 @@ public void testFailureAndSuccess() { // We may call both success and failure for the same block. 
verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2")); - verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any()); } @Test public void testEmptyBlockFetch() { try { - fetchBlocks(Maps.newLinkedHashMap()); + fetchBlocks(Maps.newLinkedHashMap()); fail(); } catch (IllegalArgumentException e) { assertEquals("Zero-sized blockIds array", e.getMessage()); @@ -123,52 +121,46 @@ public void testEmptyBlockFetch() { * * If a block's buffer is "null", an exception will be thrown instead. */ - private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) { + private static BlockFetchingListener fetchBlocks(LinkedHashMap blocks) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener); - // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123 - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( - (ByteBuffer) invocationOnMock.getArguments()[0]); - RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; - callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); - assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); - return null; - } + // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123 + doAnswer(invocationOnMock -> { + BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer( + (ByteBuffer) invocationOnMock.getArguments()[0]); + RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1]; + callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer()); + assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message); + return null; }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class)); // Respond to each chunk request with a single buffer from our blocks array. 
- final AtomicInteger expectedChunkIndex = new AtomicInteger(0); - final Iterator blockIterator = blocks.values().iterator(); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - try { - long streamId = (Long) invocation.getArguments()[0]; - int myChunkIndex = (Integer) invocation.getArguments()[1]; - assertEquals(123, streamId); - assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); - - ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; - ManagedBuffer result = blockIterator.next(); - if (result != null) { - callback.onSuccess(myChunkIndex, result); - } else { - callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); - } - } catch (Exception e) { - e.printStackTrace(); - fail("Unexpected failure"); + AtomicInteger expectedChunkIndex = new AtomicInteger(0); + Iterator blockIterator = blocks.values().iterator(); + doAnswer(invocation -> { + try { + long streamId = (Long) invocation.getArguments()[0]; + int myChunkIndex = (Integer) invocation.getArguments()[1]; + assertEquals(123, streamId); + assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex); + + ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2]; + ManagedBuffer result = blockIterator.next(); + if (result != null) { + callback.onSuccess(myChunkIndex, result); + } else { + callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex)); } - return null; + } catch (Exception e) { + e.printStackTrace(); + fail("Unexpected failure"); } - }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any()); + return null; + }).when(client).fetchChunk(anyLong(), anyInt(), any()); fetcher.start(); return listener; diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index 91882e3b3bcd5..a530e16734db4 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -27,10 +27,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; -import org.junit.After; -import org.junit.Before; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; import org.mockito.stubbing.Stubber; @@ -39,7 +36,7 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.MapConfigProvider; import org.apache.spark.network.util.TransportConf; import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; @@ -53,20 +50,8 @@ public class RetryingBlockFetcherSuite { ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); - @Before - public void beforeEach() { - System.setProperty("spark.shuffle.io.maxRetries", "2"); - System.setProperty("spark.shuffle.io.retryWait", "0"); - } - - @After - public void afterEach() { - System.clearProperty("spark.shuffle.io.maxRetries"); - System.clearProperty("spark.shuffle.io.retryWait"); - } - @Test - public void testNoFailures() throws IOException { + public void 
testNoFailures() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -85,7 +70,7 @@ public void testNoFailures() throws IOException { } @Test - public void testUnrecoverableFailure() throws IOException { + public void testUnrecoverableFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -98,13 +83,13 @@ public void testUnrecoverableFailure() throws IOException { performInteractions(interactions, listener); - verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchFailure(eq("b0"), any()); verify(listener).onBlockFetchSuccess("b1", block1); verifyNoMoreInteractions(listener); } @Test - public void testSingleIOExceptionOnFirst() throws IOException { + public void testSingleIOExceptionOnFirst() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -127,7 +112,7 @@ public void testSingleIOExceptionOnFirst() throws IOException { } @Test - public void testSingleIOExceptionOnSecond() throws IOException { + public void testSingleIOExceptionOnSecond() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -149,7 +134,7 @@ public void testSingleIOExceptionOnSecond() throws IOException { } @Test - public void testTwoIOExceptions() throws IOException { + public void testTwoIOExceptions() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -177,7 +162,7 @@ public void testTwoIOExceptions() throws IOException { } @Test - public void testThreeIOExceptions() throws IOException { + public void testThreeIOExceptions() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -204,12 +189,12 @@ public void testThreeIOExceptions() throws IOException { performInteractions(interactions, listener); verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); verifyNoMoreInteractions(listener); } @Test - public void testRetryAndUnrecoverable() throws IOException { + public void testRetryAndUnrecoverable() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); List> interactions = Arrays.asList( @@ -234,7 +219,7 @@ public void testRetryAndUnrecoverable() throws IOException { performInteractions(interactions, listener); verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); verifyNoMoreInteractions(listener); } @@ -252,48 +237,48 @@ public void testRetryAndUnrecoverable() throws IOException { @SuppressWarnings("unchecked") private static void performInteractions(List> interactions, BlockFetchingListener listener) - throws IOException { + throws IOException, InterruptedException { - TransportConf conf = new TransportConf("shuffle", new 
SystemPropertyConfigProvider()); + MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of( + "spark.shuffle.io.maxRetries", "2", + "spark.shuffle.io.retryWait", "0")); + TransportConf conf = new TransportConf("shuffle", provider); BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); Stubber stub = null; // Contains all blockIds that are referenced across all interactions. - final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + LinkedHashSet blockIds = Sets.newLinkedHashSet(); - for (final Map interaction : interactions) { + for (Map interaction : interactions) { blockIds.addAll(interaction.keySet()); - Answer answer = new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - try { - // Verify that the RetryingBlockFetcher requested the expected blocks. - String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; - String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); - assertArrayEquals(desiredBlockIds, requestedBlockIds); - - // Now actually invoke the success/failure callbacks on each block. - BlockFetchingListener retryListener = - (BlockFetchingListener) invocationOnMock.getArguments()[1]; - for (Map.Entry block : interaction.entrySet()) { - String blockId = block.getKey(); - Object blockValue = block.getValue(); - - if (blockValue instanceof ManagedBuffer) { - retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); - } else if (blockValue instanceof Exception) { - retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); - } else { - fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); - } + Answer answer = invocationOnMock -> { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. 
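The pattern running through these test changes is to hand TransportConf an in-memory MapConfigProvider instead of mutating JVM-wide system properties, so each test carries its own settings and nothing leaks between suites. A minimal sketch of that pattern, assuming only the classes already visible in this diff (the fixture class name is illustrative):

```java
import com.google.common.collect.ImmutableMap;

import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

class TransportConfFixtures {
  // Per-test configuration: no System.setProperty/clearProperty bookkeeping needed.
  static TransportConf retryingConf() {
    MapConfigProvider provider = new MapConfigProvider(ImmutableMap.of(
        "spark.shuffle.io.maxRetries", "2",
        "spark.shuffle.io.retryWait", "0"));
    return new TransportConf("shuffle", provider);
  }

  // An empty provider simply falls back to TransportConf's built-in defaults.
  static TransportConf defaultConf() {
    return new TransportConf("shuffle", MapConfigProvider.EMPTY);
  }
}
```

With the configuration scoped to the returned object, the @Before/@After hooks that previously set and cleared system properties can simply be deleted, which is what the RetryingBlockFetcherSuite change above does.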
+ BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); } - return null; - } catch (Throwable e) { - e.printStackTrace(); - throw e; } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; } }; @@ -306,7 +291,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } assertNotNull(stub); - stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + stub.when(fetchStarter).createAndStart(any(), anyObject()); String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); } diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 606ad15739617..a8488d8d1b704 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -50,6 +50,17 @@ spark-tags_${scala.binary.version} + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.hadoop diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index ea726e3c8240e..fd50e3a4bfb9b 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.List; import java.util.Map; @@ -45,7 +44,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.sasl.SaslServerBootstrap; +import org.apache.spark.network.crypto.AuthServerBootstrap; import org.apache.spark.network.sasl.ShuffleSecretManager; import org.apache.spark.network.server.TransportServer; import org.apache.spark.network.server.TransportServerBootstrap; @@ -172,7 +171,7 @@ protected void serviceInit(Configuration conf) throws Exception { boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); if (authEnabled) { createSecretManager(); - bootstraps.add(new SaslServerBootstrap(transportConf, secretManager)); + bootstraps.add(new AuthServerBootstrap(transportConf, secretManager)); } int port = conf.getInt( @@ -340,9 +339,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. 
*/ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -350,7 +349,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -363,17 +362,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. - File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -381,7 +384,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java index 884861752e80d..8beb033699471 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -17,6 +17,7 @@ package org.apache.spark.network.yarn.util; +import java.util.Map; import java.util.NoSuchElementException; import org.apache.hadoop.conf.Configuration; @@ -39,4 +40,16 @@ public String get(String name) { } return value; } + + @Override + public String get(String name, String defaultValue) { + String value = conf.get(name); + return value == null ? 
defaultValue : value; + } + + @Override + public Iterable> getAll() { + return conf; + } + } diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 626f023a5b99c..6b81fc2b2b040 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + @@ -49,6 +61,7 @@ net.alchim31.maven scala-maven-plugin + 3.2.2 @@ -59,6 +72,7 @@ org.apache.maven.plugins maven-compiler-plugin + 3.6.1 diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java index 40fa20c4a3e37..f7c22dddb8cc0 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketch.java @@ -17,12 +17,13 @@ package org.apache.spark.util.sketch; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; /** - * A Count-min sketch is a probabilistic data structure used for summarizing streams of data in + * A Count-min sketch is a probabilistic data structure used for cardinality estimation using * sub-linear space. Currently, supported data types include: *
 * <ul>
 *   <li>{@link Byte}</li>
  • @@ -173,6 +174,11 @@ public abstract CountMinSketch mergeInPlace(CountMinSketch other) */ public abstract void writeTo(OutputStream out) throws IOException; + /** + * Serializes this {@link CountMinSketch} and returns the serialized form. + */ + public abstract byte[] toByteArray() throws IOException; + /** * Reads in a {@link CountMinSketch} from an input stream. It is the caller's responsibility to * close the stream. @@ -181,6 +187,16 @@ public static CountMinSketch readFrom(InputStream in) throws IOException { return CountMinSketchImpl.readFrom(in); } + /** + * Reads in a {@link CountMinSketch} from a byte array. + */ + public static CountMinSketch readFrom(byte[] bytes) throws IOException { + InputStream in = new ByteArrayInputStream(bytes); + CountMinSketch cms = readFrom(in); + in.close(); + return cms; + } + /** * Creates a {@link CountMinSketch} with given {@code depth}, {@code width}, and random * {@code seed}. diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java index 2acbb247b13cd..045fec33a282a 100644 --- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java +++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java @@ -17,14 +17,7 @@ package org.apache.spark.util.sketch; -import java.io.DataInputStream; -import java.io.DataOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.OutputStream; -import java.io.Serializable; +import java.io.*; import java.util.Arrays; import java.util.Random; @@ -152,6 +145,8 @@ public void add(Object item) { public void add(Object item, long count) { if (item instanceof String) { addString((String) item, count); + } else if (item instanceof byte[]) { + addBinary((byte[]) item, count); } else { addLong(Utils.integralToLong(item), count); } @@ -234,6 +229,8 @@ private static int[] getHashBuckets(byte[] b, int hashCount, int max) { public long estimateCount(Object item) { if (item instanceof String) { return estimateCountForStringItem((String) item); + } else if (item instanceof byte[]) { + return estimateCountForBinaryItem((byte[]) item); } else { return estimateCountForLongItem(Utils.integralToLong(item)); } @@ -256,6 +253,15 @@ private long estimateCountForStringItem(String item) { return res; } + private long estimateCountForBinaryItem(byte[] item) { + long res = Long.MAX_VALUE; + int[] buckets = getHashBuckets(item, depth, width); + for (int i = 0; i < depth; ++i) { + res = Math.min(res, table[i][buckets[i]]); + } + return res; + } + @Override public CountMinSketch mergeInPlace(CountMinSketch other) throws IncompatibleMergeException { if (other == null) { @@ -314,6 +320,14 @@ public void writeTo(OutputStream out) throws IOException { } } + @Override + public byte[] toByteArray() throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + writeTo(out); + out.close(); + return out.toByteArray(); + } + public static CountMinSketchImpl readFrom(InputStream in) throws IOException { CountMinSketchImpl sketch = new CountMinSketchImpl(); sketch.readFrom0(in); diff --git a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala index b9c7f5c23a8fe..174eb01986c4f 100644 --- 
a/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala +++ b/common/sketch/src/test/scala/org/apache/spark/util/sketch/CountMinSketchSuite.scala @@ -25,9 +25,9 @@ import scala.util.Random import org.scalatest.FunSuite // scalastyle:ignore funsuite class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite - private val epsOfTotalCount = 0.0001 + private val epsOfTotalCount = 0.01 - private val confidence = 0.99 + private val confidence = 0.9 private val seed = 42 @@ -72,7 +72,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite if (ratio > epsOfTotalCount) 1 else 0 }.sum - 1D - numErrors.toDouble / numAllItems + 1.0 - (numErrors.toDouble / numAllItems) } assert( @@ -89,9 +89,7 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val numToMerge = 5 val numItemsPerSketch = 100000 - val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { - itemGenerator(r) - } + val perSketchItems = Array.fill(numToMerge, numItemsPerSketch) { itemGenerator(r) } val sketches = perSketchItems.map { items => val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) @@ -106,11 +104,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite val mergedSketch = sketches.reduce(_ mergeInPlace _) checkSerDe(mergedSketch) - val expectedSketch = { - val sketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) - perSketchItems.foreach(_.foreach(sketch.add)) - sketch - } + val expectedSketch = CountMinSketch.create(epsOfTotalCount, confidence, seed) + perSketchItems.foreach(_.foreach(expectedSketch.add)) perSketchItems.foreach { _.foreach { item => @@ -135,6 +130,8 @@ class CountMinSketchSuite extends FunSuite { // scalastyle:ignore funsuite testItemType[String]("String") { r => r.nextString(r.nextInt(20)) } + testItemType[Array[Byte]]("Byte array") { r => r.nextString(r.nextInt(60)).getBytes } + test("incompatible merge") { intercept[IncompatibleMergeException] { CountMinSketch.create(10, 10, 1).mergeInPlace(null) diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 1c60d510e5703..f7e586ee777e1 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -36,9 +36,9 @@ - org.scalatest - scalatest_${scala.binary.version} - compile + org.scala-lang + scala-library + ${scala.version} diff --git a/common/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/common/tags/src/test/java/org/apache/spark/tags/DockerTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/DockerTest.java rename to common/tags/src/test/java/org/apache/spark/tags/DockerTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedHiveTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedHiveTest.java diff --git a/common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java b/common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java similarity index 100% rename from common/tags/src/main/java/org/apache/spark/tags/ExtendedYarnTest.java rename to common/tags/src/test/java/org/apache/spark/tags/ExtendedYarnTest.java diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 45af98d94ef91..680d0413b1616 100644 --- 
a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../../pom.xml @@ -39,6 +39,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + com.twitter chill_${scala.binary.version} @@ -86,6 +98,7 @@ net.alchim31.maven scala-maven-plugin + 3.2.2 @@ -96,6 +109,7 @@ org.apache.maven.plugins maven-compiler-plugin + 3.6.1 diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java index c7ea9085eba66..73577437ac506 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java @@ -20,7 +20,7 @@ import org.apache.spark.unsafe.Platform; /** - * Simulates Hive's hashing function at + * Simulates Hive's hashing function from Hive v1.2.1 * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() */ public class HiveHasher { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 671b8c7475943..aca6fca00c48b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -46,18 +46,23 @@ public final class Platform { private static final boolean unaligned; static { boolean _unaligned; - // use reflection to access unaligned field - try { - Class bitsClass = - Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); - Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); - unalignedMethod.setAccessible(true); - _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); - } catch (Throwable t) { - // We at least know x86 and x64 support unaligned access. - String arch = System.getProperty("os.arch", ""); - //noinspection DynamicRegexReplaceableByCompiledPattern - _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + String arch = System.getProperty("os.arch", ""); + if (arch.equals("ppc64le") || arch.equals("ppc64")) { + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it + _unaligned = true; + } else { + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. 
+ //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64|aarch64)$"); + } } unaligned = _unaligned; } @@ -162,14 +167,9 @@ public static ByteBuffer allocateDirectBuffer(int size) { constructor.setAccessible(true); Field cleanerField = cls.getDeclaredField("cleaner"); cleanerField.setAccessible(true); - final long memory = allocateMemory(size); + long memory = allocateMemory(size); ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size); - Cleaner cleaner = Cleaner.create(buffer, new Runnable() { - @Override - public void run() { - freeMemory(memory); - } - }); + Cleaner cleaner = Cleaner.create(buffer, () -> freeMemory(memory)); cleanerField.set(buffer, cleaner); return buffer; } catch (Exception e) { diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 518ed6470a753..621f2c6bf3777 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -178,48 +178,52 @@ public static CalendarInterval fromSingleUnitString(String unit, String s) "Interval string does not match day-time format of 'd h:m:s.n': " + s); } else { try { - if (unit.equals("year")) { - int year = (int) toLongWithRange("year", m.group(1), - Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); - result = new CalendarInterval(year * 12, 0L); - - } else if (unit.equals("month")) { - int month = (int) toLongWithRange("month", m.group(1), - Integer.MIN_VALUE, Integer.MAX_VALUE); - result = new CalendarInterval(month, 0L); - - } else if (unit.equals("week")) { - long week = toLongWithRange("week", m.group(1), - Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK); - result = new CalendarInterval(0, week * MICROS_PER_WEEK); - - } else if (unit.equals("day")) { - long day = toLongWithRange("day", m.group(1), - Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); - result = new CalendarInterval(0, day * MICROS_PER_DAY); - - } else if (unit.equals("hour")) { - long hour = toLongWithRange("hour", m.group(1), - Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); - result = new CalendarInterval(0, hour * MICROS_PER_HOUR); - - } else if (unit.equals("minute")) { - long minute = toLongWithRange("minute", m.group(1), - Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); - result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); - - } else if (unit.equals("second")) { - long micros = parseSecondNano(m.group(1)); - result = new CalendarInterval(0, micros); - - } else if (unit.equals("millisecond")) { - long millisecond = toLongWithRange("millisecond", m.group(1), - Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI); - result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI); - - } else if (unit.equals("microsecond")) { - long micros = Long.parseLong(m.group(1)); - result = new CalendarInterval(0, micros); + switch (unit) { + case "year": + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + break; + case "month": + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + break; + case "week": + long week = toLongWithRange("week", m.group(1), + 
Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK); + result = new CalendarInterval(0, week * MICROS_PER_WEEK); + break; + case "day": + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + break; + case "hour": + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + break; + case "minute": + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + break; + case "second": { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + break; + } + case "millisecond": + long millisecond = toLongWithRange("millisecond", m.group(1), + Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI); + result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI); + break; + case "microsecond": { + long micros = Long.parseLong(m.group(1)); + result = new CalendarInterval(0, micros); + break; + } } } catch (Exception e) { throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); @@ -252,6 +256,10 @@ public static long parseSecondNano(String secondNano) throws IllegalArgumentExce public final int months; public final long microseconds; + public long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + public CalendarInterval(int months, long microseconds) { this.months = months; this.microseconds = microseconds; diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index e09a6b7d93a93..5437e998c085f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -147,6 +147,40 @@ public void writeTo(ByteBuffer buffer) { buffer.position(pos + numBytes); } + /** + * Returns a {@link ByteBuffer} wrapping the base object if it is a byte array + * or a copy of the data if the base object is not a byte array. + * + * Unlike getBytes this will not create a copy the array if this is a slice. + */ + @Nonnull + public ByteBuffer getByteBuffer() { + if (base instanceof byte[] && offset >= BYTE_ARRAY_OFFSET) { + final byte[] bytes = (byte[]) base; + + // the offset includes an object header... 
this is only needed for unsafe copies + final long arrayOffset = offset - BYTE_ARRAY_OFFSET; + + // verify that the offset and length points somewhere inside the byte array + // and that the offset can safely be truncated to a 32-bit integer + if ((long) bytes.length < arrayOffset + numBytes) { + throw new ArrayIndexOutOfBoundsException(); + } + + return ByteBuffer.wrap(bytes, (int) arrayOffset, numBytes); + } else { + return ByteBuffer.wrap(getBytes()); + } + } + + public void writeTo(OutputStream out) throws IOException { + final ByteBuffer bb = this.getByteBuffer(); + assert(bb.hasArray()); + + // similar to Utils.writeByteBuffer but without the spark-core dependency + out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()); + } + /** * Returns the number of bytes for a code point with the first byte as `b` * @param b The first byte of a code point @@ -816,6 +850,225 @@ public UTF8String translate(Map dict) { return fromString(sb.toString()); } + /** + * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. + */ + public static class LongWrapper { + public long value = 0; + } + + /** + * Wrapper over `int` to allow result of parsing integer from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. + * + * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of + * conversion from `long` to `int` + */ + public static class IntWrapper { + public int value = 0; + } + + /** + * Parses this UTF8String to long. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. + * + * This code is mostly copied from LazyLong.parseLong in Hive. + * + * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would + * be set in `toLongResult` + * @return true if the parsing was successful else false + */ + public boolean toLong(LongWrapper toLongResult) { + if (numBytes == 0) { + return false; + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + return false; + } + } + + final byte separator = '.'; + final int radix = 10; + final long stopValue = Long.MIN_VALUE / radix; + long result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop. + if (result < stopValue) { + return false; + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we + // can just use `result > 0` to check overflow. 
If result overflows, we should stop. + if (result > 0) { + return false; + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. + while (offset < numBytes) { + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + return false; + } + } + + toLongResult.value = result; + return true; + } + + /** + * Parses this UTF8String to int. + * + * Note that, in this method we accumulate the result in negative format, and convert it to + * positive format at the end, if this string is not started with '-'. This is because min value + * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and + * Integer.MIN_VALUE is '-2147483648'. + * + * This code is mostly copied from LazyInt.parseInt in Hive. + * + * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance + * reasons, like Hive does. + * + * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would + * be set in `intWrapper` + * @return true if the parsing was successful else false + */ + public boolean toInt(IntWrapper intWrapper) { + if (numBytes == 0) { + return false; + } + + byte b = getByte(0); + final boolean negative = b == '-'; + int offset = 0; + if (negative || b == '+') { + offset++; + if (numBytes == 1) { + return false; + } + } + + final byte separator = '.'; + final int radix = 10; + final int stopValue = Integer.MIN_VALUE / radix; + int result = 0; + + while (offset < numBytes) { + b = getByte(offset); + offset++; + if (b == separator) { + // We allow decimals and will return a truncated integral in that case. + // Therefore we won't throw an exception here (checking the fractional + // part happens below.) + break; + } + + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + + // We are going to process the new digit and accumulate the result. However, before doing + // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then + // result * 10 will definitely be smaller than minValue, and we can stop + if (result < stopValue) { + return false; + } + + result = result * radix - digit; + // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), + // we can just use `result > 0` to check overflow. If result overflows, we should stop + if (result > 0) { + return false; + } + } + + // This is the case when we've encountered a decimal separator. The fractional + // part will not change the number, but we will verify that the fractional part + // is well formed. 
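The parsing above accumulates in the negative range because Long.MIN_VALUE (and Integer.MIN_VALUE) has one more unit of magnitude than its MAX counterpart, so a string like "-9223372036854775808" could never be built up as a positive value and then negated. A stripped-down sketch of the same trick for long values, with the decimal-separator handling omitted and the helper name purely illustrative:

```java
// Hypothetical helper, for illustration only; mirrors the negative-accumulation idea above.
static boolean parseLong(String s, long[] out) {
  if (s.isEmpty()) {
    return false;
  }
  boolean negative = s.charAt(0) == '-';
  int i = (negative || s.charAt(0) == '+') ? 1 : 0;
  if (i == s.length()) {
    return false;
  }

  final long stop = Long.MIN_VALUE / 10;   // below this, result * 10 must underflow
  long result = 0;
  for (; i < s.length(); i++) {
    int digit = s.charAt(i) - '0';
    if (digit < 0 || digit > 9) {
      return false;
    }
    if (result < stop) {
      return false;                        // next multiply would underflow
    }
    result = result * 10 - digit;
    if (result > 0) {
      return false;                        // subtraction wrapped around: overflow
    }
  }
  if (!negative) {
    result = -result;
    if (result < 0) {
      return false;                        // magnitude only representable as Long.MIN_VALUE
    }
  }
  out[0] = result;
  return true;
}
```

Walking "-9223372036854775808" through this loop ends with result equal to Long.MIN_VALUE and no final negation, which is exactly the value a positive accumulator could never hold.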
+ while (offset < numBytes) { + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; + } + offset++; + } + + if (!negative) { + result = -result; + if (result < 0) { + return false; + } + } + intWrapper.value = result; + return true; + } + + public boolean toShort(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + short result = (short) intValue; + if (result == intValue) { + return true; + } + } + return false; + } + + public boolean toByte(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + byte result = (byte) intValue; + if (result == intValue) { + return true; + } + } + return false; + } + @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7f03686dcec41..c376371abdf90 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -17,15 +17,20 @@ package org.apache.spark.unsafe.types; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; +import java.util.*; import com.google.common.collect.ImmutableMap; +import org.apache.spark.unsafe.Platform; import org.junit.Test; import static org.junit.Assert.*; +import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; import static org.apache.spark.unsafe.types.UTF8String.*; public class UTF8StringSuite { @@ -499,4 +504,230 @@ public void soundex() { assertEquals(fromString("123").soundex(), fromString("123")); assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); } + + @Test + public void writeToOutputStreamUnderflow() throws IOException { + // offset underflow is apparently supported? 
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); + + for (int i = 1; i <= Platform.BYTE_ARRAY_OFFSET; ++i) { + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET - i, test.length + i) + .writeTo(outputStream); + final ByteBuffer buffer = ByteBuffer.wrap(outputStream.toByteArray(), i, test.length); + assertEquals("01234567", StandardCharsets.UTF_8.decode(buffer).toString()); + outputStream.reset(); + } + } + + @Test + public void writeToOutputStreamSlice() throws IOException { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); + + for (int i = 0; i < test.length; ++i) { + for (int j = 0; j < test.length - i; ++j) { + UTF8String.fromAddress(test, Platform.BYTE_ARRAY_OFFSET + i, j) + .writeTo(outputStream); + + assertArrayEquals(Arrays.copyOfRange(test, i, i + j), outputStream.toByteArray()); + outputStream.reset(); + } + } + } + + @Test + public void writeToOutputStreamOverflow() throws IOException { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + final byte[] test = "01234567".getBytes(StandardCharsets.UTF_8); + + final HashSet offsets = new HashSet<>(); + for (int i = 0; i < 16; ++i) { + // touch more points around MAX_VALUE + offsets.add((long) Integer.MAX_VALUE - i); + // subtract off BYTE_ARRAY_OFFSET to avoid wrapping around to a negative value, + // which will hit the slower copy path instead of the optimized one + offsets.add(Long.MAX_VALUE - BYTE_ARRAY_OFFSET - i); + } + + for (long i = 1; i > 0L; i <<= 1) { + for (long j = 0; j < 32L; ++j) { + offsets.add(i + j); + } + } + + for (final long offset : offsets) { + try { + fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length) + .writeTo(outputStream); + + throw new IllegalStateException(Long.toString(offset)); + } catch (ArrayIndexOutOfBoundsException e) { + // ignore + } finally { + outputStream.reset(); + } + } + } + + @Test + public void writeToOutputStream() throws IOException { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + EMPTY_UTF8.writeTo(outputStream); + assertEquals("", outputStream.toString("UTF-8")); + outputStream.reset(); + + fromString("数据砖很重").writeTo(outputStream); + assertEquals( + "数据砖很重", + outputStream.toString("UTF-8")); + outputStream.reset(); + } + + @Test + public void writeToOutputStreamIntArray() throws IOException { + // verify that writes work on objects that are not byte arrays + final ByteBuffer buffer = StandardCharsets.UTF_8.encode("大千世界"); + buffer.position(0); + buffer.order(ByteOrder.nativeOrder()); + + final int length = buffer.limit(); + assertEquals(12, length); + + final int ints = length / 4; + final int[] array = new int[ints]; + + for (int i = 0; i < ints; ++i) { + array[i] = buffer.getInt(); + } + + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + fromAddress(array, Platform.INT_ARRAY_OFFSET, length) + .writeTo(outputStream); + assertEquals("大千世界", outputStream.toString("UTF-8")); + } + + @Test + public void testToShort() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (short) 1); + inputToExpectedOutput.put("+1", (short) 1); + inputToExpectedOutput.put("-1", (short) -1); + inputToExpectedOutput.put("0", (short) 0); + inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111); + inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE); + 
inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + short value = (short) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper wrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper)); + assertEquals((short) entry.getValue(), wrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper)); + } + } + + @Test + public void testToByte() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (byte) 1); + inputToExpectedOutput.put("+1",(byte) 1); + inputToExpectedOutput.put("-1", (byte) -1); + inputToExpectedOutput.put("0", (byte) 0); + inputToExpectedOutput.put("111.12345678901234567890", (byte) 111); + inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + byte value = (byte) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper)); + assertEquals((byte) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper)); + } + } + + @Test + public void testToInt() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1); + inputToExpectedOutput.put("+1", 1); + inputToExpectedOutput.put("-1", -1); + inputToExpectedOutput.put("0", 0); + inputToExpectedOutput.put("11111.1234567", 11111); + inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + int value = rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper)); + assertEquals((int) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper)); + } + } + + @Test + public void testToLong() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1L); + inputToExpectedOutput.put("+1", 1L); + inputToExpectedOutput.put("-1", -1L); + inputToExpectedOutput.put("0", 0L); + inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L); + inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE); 
+ + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + long value = rand.nextLong(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + LongWrapper wrapper = new LongWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper)); + assertEquals((long) entry.getValue(), wrapper.value); + } + + List negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", + "1234567890123456789012345678901234"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); + } + } } diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 5c1e876ef9afc..b7c985ace69cf 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -25,18 +25,15 @@ # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public dns name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # Options read by executors and drivers running inside the cluster # - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node # - SPARK_PUBLIC_DNS, to set the public DNS name of the driver program -# - SPARK_CLASSPATH, default classpath entries to append # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos # Options read in YARN client mode # - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - SPARK_EXECUTOR_INSTANCES, Number of executors to start (Default: 2) # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) @@ -48,7 +45,6 @@ # - SPARK_WORKER_CORES, to set the number of cores to use on this machine # - SPARK_WORKER_MEMORY, to set how much total memory workers have to give executors (e.g. 1000m, 2g) # - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT, to use non-default ports for the worker -# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node # - SPARK_WORKER_DIR, to set the working directory of worker processes # - SPARK_WORKER_OPTS, to set config properties only for the worker (e.g. "-Dx=y") # - SPARK_DAEMON_MEMORY, to allocate to the master, worker and history server themselves (default: 1g). 
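The template above no longer lists SPARK_CLASSPATH or SPARK_EXECUTOR_INSTANCES; their usual replacements are plain SparkConf properties. A sketch under that assumption (the keys are standard Spark configuration names, not something this patch adds, and the jar path is a placeholder):

```java
import org.apache.spark.SparkConf;

class ConfMigrationSketch {
  static SparkConf build() {
    return new SparkConf()
        .setAppName("example")
        // roughly replaces SPARK_EXECUTOR_INSTANCES in YARN client mode
        .set("spark.executor.instances", "2")
        // roughly replaces SPARK_CLASSPATH for extra classpath entries
        .set("spark.driver.extraClassPath", "/path/to/extra.jar")
        .set("spark.executor.extraClassPath", "/path/to/extra.jar");
  }
}
```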
diff --git a/core/pom.xml b/core/pom.xml index eac99ab82a2e4..7f245b5b6384a 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.1.0-SNAPSHOT + 2.3.0-SNAPSHOT ../pom.xml @@ -33,6 +33,10 @@ Spark Project Core http://spark.apache.org/ + + org.apache.avro + avro + org.apache.avro avro-mapred @@ -337,6 +341,18 @@ org.apache.spark spark-tags_${scala.binary.version} + + + + org.apache.spark + spark-tags_${scala.binary.version} + test-jar + test + + org.apache.commons commons-crypto diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index 97eed611e8f9a..140c52fd12f94 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -30,96 +30,117 @@ */ public class SparkFirehoseListener implements SparkListenerInterface { - public void onEvent(SparkListenerEvent event) { } - - @Override - public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { - onEvent(stageCompleted); - } - - @Override - public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { - onEvent(stageSubmitted); - } - - @Override - public final void onTaskStart(SparkListenerTaskStart taskStart) { - onEvent(taskStart); - } - - @Override - public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { - onEvent(taskGettingResult); - } - - @Override - public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { - onEvent(taskEnd); - } - - @Override - public final void onJobStart(SparkListenerJobStart jobStart) { - onEvent(jobStart); - } - - @Override - public final void onJobEnd(SparkListenerJobEnd jobEnd) { - onEvent(jobEnd); - } - - @Override - public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { - onEvent(environmentUpdate); - } - - @Override - public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { - onEvent(blockManagerAdded); - } - - @Override - public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { - onEvent(blockManagerRemoved); - } - - @Override - public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { - onEvent(unpersistRDD); - } - - @Override - public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { - onEvent(applicationStart); - } - - @Override - public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { - onEvent(applicationEnd); - } - - @Override - public final void onExecutorMetricsUpdate( - SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { - onEvent(executorMetricsUpdate); - } - - @Override - public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { - onEvent(executorAdded); - } - - @Override - public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { - onEvent(executorRemoved); - } - - @Override - public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { - onEvent(blockUpdated); - } - - @Override - public void onOtherEvent(SparkListenerEvent event) { - onEvent(event); - } + public void onEvent(SparkListenerEvent event) { } + + @Override + public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) { + onEvent(stageCompleted); + } + + @Override + public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { + onEvent(stageSubmitted); + } + + 
@Override + public final void onTaskStart(SparkListenerTaskStart taskStart) { + onEvent(taskStart); + } + + @Override + public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { + onEvent(taskGettingResult); + } + + @Override + public final void onTaskEnd(SparkListenerTaskEnd taskEnd) { + onEvent(taskEnd); + } + + @Override + public final void onJobStart(SparkListenerJobStart jobStart) { + onEvent(jobStart); + } + + @Override + public final void onJobEnd(SparkListenerJobEnd jobEnd) { + onEvent(jobEnd); + } + + @Override + public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { + onEvent(environmentUpdate); + } + + @Override + public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { + onEvent(blockManagerAdded); + } + + @Override + public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { + onEvent(blockManagerRemoved); + } + + @Override + public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { + onEvent(unpersistRDD); + } + + @Override + public final void onApplicationStart(SparkListenerApplicationStart applicationStart) { + onEvent(applicationStart); + } + + @Override + public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { + onEvent(applicationEnd); + } + + @Override + public final void onExecutorMetricsUpdate( + SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { + onEvent(executorMetricsUpdate); + } + + @Override + public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { + onEvent(executorAdded); + } + + @Override + public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { + onEvent(executorRemoved); + } + + @Override + public final void onExecutorBlacklisted(SparkListenerExecutorBlacklisted executorBlacklisted) { + onEvent(executorBlacklisted); + } + + @Override + public final void onExecutorUnblacklisted( + SparkListenerExecutorUnblacklisted executorUnblacklisted) { + onEvent(executorUnblacklisted); + } + + @Override + public final void onNodeBlacklisted(SparkListenerNodeBlacklisted nodeBlacklisted) { + onEvent(nodeBlacklisted); + } + + @Override + public final void onNodeUnblacklisted(SparkListenerNodeUnblacklisted nodeUnblacklisted) { + onEvent(nodeUnblacklisted); + } + + @Override + public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { + onEvent(blockUpdated); + } + + @Override + public void onOtherEvent(SparkListenerEvent event) { + onEvent(event); + } } diff --git a/core/src/main/java/org/apache/spark/api/java/Optional.java b/core/src/main/java/org/apache/spark/api/java/Optional.java index ca7babc3f01c7..fd0f495ca29da 100644 --- a/core/src/main/java/org/apache/spark/api/java/Optional.java +++ b/core/src/main/java/org/apache/spark/api/java/Optional.java @@ -18,6 +18,7 @@ package org.apache.spark.api.java; import java.io.Serializable; +import java.util.Objects; import com.google.common.base.Preconditions; @@ -52,8 +53,8 @@ *
 *   <li>{@link #isPresent()}</li>
 * </ul>
 *
- * {@code java.util.Optional} itself is not used at this time because the
- * project does not require Java 8. Using {@code com.google.common.base.Optional}
+ * {@code java.util.Optional} itself was not used because at the time, the
+ * project did not require Java 8. Using {@code com.google.common.base.Optional}
  * has in the past caused serious library version conflicts with Guava that can't
  * be resolved by shading. Hence this work-alike clone.
* @@ -171,7 +172,7 @@ public boolean equals(Object obj) { return false; } Optional other = (Optional) obj; - return value == null ? other.value == null : value.equals(other.value); + return Objects.equals(value, other.value); } @Override diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java index 07aebb75e8f4e..33bedf7ebcb07 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -24,6 +24,7 @@ * A function that returns zero or more output records from each grouping key and its values from 2 * Datasets. */ +@FunctionalInterface public interface CoGroupFunction extends Serializable { Iterator call(K key, Iterator left, Iterator right) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java index 576087b6f428e..2f23da5bfec1c 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFlatMapFunction.java @@ -23,6 +23,7 @@ /** * A function that returns zero or more records of type Double from each input record. */ +@FunctionalInterface public interface DoubleFlatMapFunction extends Serializable { Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java index bf16f791f906a..3c0291cf46240 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/DoubleFunction.java @@ -22,6 +22,7 @@ /** * A function that returns Doubles, and can be used to construct DoubleRDDs. */ +@FunctionalInterface public interface DoubleFunction extends Serializable { double call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java index 462ca3f6f6d19..a6f69f7cdca86 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -24,6 +24,7 @@ * * If the function returns true, the element is included in the returned Dataset. */ +@FunctionalInterface public interface FilterFunction extends Serializable { boolean call(T value) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 2d8ea6d1a5a7e..91d61292f167f 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,6 +23,7 @@ /** * A function that returns zero or more output records from each input record. 
*/ +@FunctionalInterface public interface FlatMapFunction extends Serializable { Iterator call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index fc97b63f825d0..f9f2580b01f45 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,6 +23,7 @@ /** * A function that takes two inputs and returns zero or more output records. */ +@FunctionalInterface public interface FlatMapFunction2 extends Serializable { Iterator call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java index bae574ab5755d..6423c5d0fce56 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsFunction.java @@ -23,6 +23,7 @@ /** * A function that returns zero or more output records from each grouping key and its values. */ +@FunctionalInterface public interface FlatMapGroupsFunction extends Serializable { Iterator call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java index 07e54b28fa12c..2e6e90818d580 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -24,6 +24,7 @@ * * Spark will invoke the call function on each element in the input Dataset. */ +@FunctionalInterface public interface ForeachFunction extends Serializable { void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java index 4938a51bcd712..d8f55d0ae1dc0 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -23,6 +23,7 @@ /** * Base interface for a function used in Dataset's foreachPartition function. */ +@FunctionalInterface public interface ForeachPartitionFunction extends Serializable { void call(Iterator t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java index b9d9777a75651..8b2bbd501c498 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java @@ -24,6 +24,7 @@ * DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed * when mapping RDDs of other types. 
*/ +@FunctionalInterface public interface Function extends Serializable { R call(T1 v1) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index c86928dd05408..5c649d9de414d 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -22,6 +22,7 @@ /** * A zero-argument function that returns an R. */ +@FunctionalInterface public interface Function0 extends Serializable { R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function2.java b/core/src/main/java/org/apache/spark/api/java/function/Function2.java index a975ce3c68192..a7d9647095151 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function2.java @@ -22,6 +22,7 @@ /** * A two-argument function that takes arguments of type T1 and T2 and returns an R. */ +@FunctionalInterface public interface Function2 extends Serializable { R call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function3.java b/core/src/main/java/org/apache/spark/api/java/function/Function3.java index 6eecfb645a663..77acd21d4eff7 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function3.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function3.java @@ -22,6 +22,7 @@ /** * A three-argument function that takes arguments of type T1, T2 and T3 and returns an R. */ +@FunctionalInterface public interface Function3 extends Serializable { R call(T1 v1, T2 v2, T3 v3) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java index 9c35a22ca9d0f..d530ba446b3c2 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function4.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -22,6 +22,7 @@ /** * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. */ +@FunctionalInterface public interface Function4 extends Serializable { R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java index 3ae6ef44898e1..5efff943c8cdc 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -22,6 +22,7 @@ /** * Base interface for a map function used in Dataset's map function. */ +@FunctionalInterface public interface MapFunction extends Serializable { U call(T value) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java index faa59eabc8b4f..2c3d43afc0b3e 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupsFunction.java @@ -23,6 +23,7 @@ /** * Base interface for a map function used in GroupedDataset's mapGroup function. 
*/ +@FunctionalInterface public interface MapGroupsFunction extends Serializable { R call(K key, Iterator values) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java index cf9945a215aff..68e8557c88d1b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -23,6 +23,7 @@ /** * Base interface for function used in Dataset's mapPartitions. */ +@FunctionalInterface public interface MapPartitionsFunction extends Serializable { Iterator call(Iterator input) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java index 51eed2e67b9fa..97bd2b37a059c 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -26,6 +26,7 @@ * A function that returns zero or more key-value pair records from each input record. The * key-value pairs are represented as scala.Tuple2 objects. */ +@FunctionalInterface public interface PairFlatMapFunction extends Serializable { Iterator> call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java index 2fdfa7184a3bd..34a7e4489a319 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java @@ -25,6 +25,7 @@ * A function that returns key-value pairs (Tuple2<K, V>), and can be used to * construct PairRDDs. */ +@FunctionalInterface public interface PairFunction extends Serializable { Tuple2 call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java index ee092d0058f44..d9029d85387ae 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -22,6 +22,7 @@ /** * Base interface for function used in Dataset's reduce. */ +@FunctionalInterface public interface ReduceFunction extends Serializable { T call(T v1, T v2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java index f30d42ee57966..aff2bc6e94fb3 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction.java @@ -22,6 +22,7 @@ /** * A function with no return value. 
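
The `@FunctionalInterface` annotations being added across `org.apache.spark.api.java.function` document (and let the compiler enforce) that each of these interfaces has exactly one abstract `call` method, which is what makes them usable as lambda targets from Java 8 code. A hedged usage sketch, assuming an existing `JavaRDD<String>` named `lines` (the RDD and its contents are illustrative only):

```java
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import scala.Tuple2;

final class LambdaUsageSketch {
  // Counts occurrences of the first word of each non-empty line.
  static JavaPairRDD<String, Integer> firstWordCounts(JavaRDD<String> lines) {
    return lines
        .filter(line -> !line.isEmpty())                        // Function<String, Boolean> as a lambda
        .mapToPair(line -> new Tuple2<>(line.split(" ")[0], 1)) // PairFunction<String, String, Integer>
        .reduceByKey((a, b) -> a + b);                          // Function2<Integer, Integer, Integer>
  }
}
```

Anonymous inner classes still work; the annotation mostly guards against someone accidentally adding a second abstract method and silently breaking lambda callers.
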
*/ +@FunctionalInterface public interface VoidFunction extends Serializable { void call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java index da9ae1c9c5cdc..ddb616241b244 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/VoidFunction2.java @@ -22,6 +22,7 @@ /** * A two-argument function that takes arguments of type T1 and T2 with no return value. */ +@FunctionalInterface public interface VoidFunction2 extends Serializable { void call(T1 v1, T2 v2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index f6d1288cb263d..ea5f1a9abf69b 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,8 +130,10 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } + //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } + //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index fc1f3a80239ba..48cf4b9455e4d 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -60,8 +60,6 @@ protected long getUsed() { /** * Force spill during building. - * - * For testing. */ public void spill() throws IOException { spill(Long.MAX_VALUE, this); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 1a700aa37554e..5f91411749167 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -20,8 +20,12 @@ import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.Arrays; +import java.util.ArrayList; import java.util.BitSet; import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -144,23 +148,46 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory + // Sort the consumers according their memory usage. So we avoid spilling the same consumer + // which is just spilled in last few times and re-spilling on it will produce many small + // spill files. 
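
The comment just above introduces the new spill-ordering strategy, and the `TaskMemoryManager` hunk that follows implements it: candidate consumers are grouped in a `TreeMap` keyed by bytes used, and the manager prefers the smallest consumer that can still cover the remaining request, falling back to the largest one. A standalone sketch of that selection rule, using a hypothetical `Consumer` type rather than Spark's `MemoryConsumer`:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

final class SpillCandidateSelection {
  /** Hypothetical stand-in for a spillable memory consumer. */
  static final class Consumer {
    final long usedBytes;
    Consumer(long usedBytes) { this.usedBytes = usedBytes; }
  }

  /** Picks the consumer using the least memory that still covers neededBytes, else the largest one. */
  static Consumer pickVictim(List<Consumer> consumers, long neededBytes) {
    TreeMap<Long, List<Consumer>> byUsage = new TreeMap<>();
    for (Consumer c : consumers) {
      byUsage.computeIfAbsent(c.usedBytes, k -> new ArrayList<>(1)).add(c);
    }
    if (byUsage.isEmpty()) {
      return null; // nothing to spill
    }
    // Smallest consumer whose usage is >= the remaining requirement, if any...
    Map.Entry<Long, List<Consumer>> entry = byUsage.ceilingEntry(neededBytes);
    if (entry == null) {
      // ...otherwise nobody is big enough on its own, so take the largest.
      entry = byUsage.lastEntry();
    }
    List<Consumer> bucket = entry.getValue();
    return bucket.get(bucket.size() - 1);
  }
}
```

Spilling the best-fitting consumer instead of whichever happens to come first avoids repeatedly re-spilling a consumer that was just spilled, which is what used to produce many tiny spill files.
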
+ TreeMap> sortedConsumers = new TreeMap<>(); for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { - try { - long released = c.spill(required - got, consumer); - if (released > 0) { - logger.debug("Task {} released {} from {} for {}", taskAttemptId, - Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); - if (got >= required) { - break; - } + long key = c.getUsed(); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + list.add(c); + } + } + while (!sortedConsumers.isEmpty()) { + // Get the consumer using the least memory more than the remaining required memory. + Map.Entry> currentEntry = + sortedConsumers.ceilingEntry(required - got); + // No consumer has used memory more than the remaining required memory. + // Get the consumer of largest used memory. + if (currentEntry == null) { + currentEntry = sortedConsumers.lastEntry(); + } + List cList = currentEntry.getValue(); + MemoryConsumer c = cList.remove(cList.size() - 1); + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } + try { + long released = c.spill(required - got, consumer); + if (released > 0) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, + Utils.bytesToString(released), c, consumer); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); + if (got >= required) { + break; } - } catch (IOException e) { - logger.error("error while calling spill() on " + c, e); - throw new OutOfMemoryError("error while calling spill() on " + c + " : " - + e.getMessage()); } + } catch (IOException e) { + logger.error("error while calling spill() on " + c, e); + throw new OutOfMemoryError("error while calling spill() on " + c + " : " + + e.getMessage()); } } } @@ -378,14 +405,14 @@ public long cleanUpAllAllocatedMemory() { for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { // In case of failed task, it's normal to see leaked memory - logger.warn("leak " + Utils.bytesToString(c.getUsed()) + " memory from " + c); + logger.debug("unreleased " + Utils.bytesToString(c.getUsed()) + " memory from " + c); } } consumers.clear(); for (MemoryBlock page : pageTable) { if (page != null) { - logger.warn("leak a page: " + page + " in task " + taskAttemptId); + logger.debug("unreleased page: " + page + " in task " + taskAttemptId); memoryManager.tungstenMemoryAllocator().free(page); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a15559e55cbd..323a5d3c52831 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -52,8 +52,7 @@ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path * writes incoming records to separate files, one file per reduce partition, then concatenates these * per-partition files to form a single output file, regions of which are served to reducers. - * Records are not buffered in memory. This is essentially identical to - * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format + * Records are not buffered in memory. It writes output in a format * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}. *

* This write path is inefficient for shuffles with large numbers of reduce partitions because it @@ -61,7 +60,7 @@ * {@link SortShuffleManager} only selects this write path when *

 * <ul>
 *    <li>no Ordering is specified,</li>
- *    <li>no Aggregator is specific, and</li>
+ *    <li>no Aggregator is specified, and</li>
 *    <li>the number of partitions is less than
 *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
 * </ul>
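
Per the list above, the bypass path is only chosen when there is no ordering, no aggregation, and few reduce partitions. The partition cutoff is the one tunable knob; a minimal sketch of setting it (the value shown is arbitrary, and 200 is the stock default):

```java
import org.apache.spark.SparkConf;

final class BypassThresholdSketch {
  static SparkConf conf() {
    // Shuffles with at most this many reduce partitions (and no ordering or aggregation)
    // are eligible for the hash-style bypass write path described above.
    return new SparkConf()
        .setAppName("bypass-merge-threshold-example")
        .set("spark.shuffle.sort.bypassMergeThreshold", "200");
  }
}
```
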
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index f235c434be7b1..8a1771848dee6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -40,6 +40,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -264,6 +266,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -289,7 +292,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -320,9 +323,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti /** * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, or in - * cases where users have explicitly disabled use of {@code transferTo} in order to work around - * kernel bugs. + * cases where the IO compression codec does not support concatenation of compressed data, when + * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in + * order to work around kernel bugs. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to. @@ -337,7 +340,11 @@ private long[] mergeSpillsWithFileStream( final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; - OutputStream mergedFileOutputStream = null; + + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. 
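
The comment above summarizes the stream layering used by the rewritten `mergeSpillsWithFileStream`: one `CountingOutputStream` tracks how many bytes reach the shared file, while each partition writes through a `CloseShieldOutputStream`, so that closing the partition-level (possibly compressing or encrypting) stream flushes it without closing the shared file stream. A reduced sketch of the same layering with plain streams only, no compression or encryption:

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;

final class PartitionedWriteSketch {
  /** Writes each partition's bytes into one shared file and returns the per-partition lengths. */
  static long[] writePartitions(File outputFile, byte[][] partitions) throws IOException {
    long[] lengths = new long[partitions.length];
    try (CountingOutputStream counting = new CountingOutputStream(new FileOutputStream(outputFile))) {
      for (int p = 0; p < partitions.length; p++) {
        long before = counting.getByteCount();
        // Closing the partition-level stream flushes it, but the shield keeps the shared stream open.
        OutputStream partitionOut = new CloseShieldOutputStream(counting);
        partitionOut.write(partitions[p]);
        partitionOut.close();
        lengths[p] = counting.getByteCount() - before;
      }
    }
    return lengths;
  }
}
```

Asking the counter for its byte count replaces the old pattern of closing the file and calling `File.length()` after every partition.
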
+ final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( + new FileOutputStream(outputFile)); boolean threwException = true; try { @@ -345,34 +352,35 @@ private long[] mergeSpillsWithFileStream( spillInputStreams[i] = new FileInputStream(spills[i].file); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() calls, so that we can close the higher + // level streams to make sure all data is really flushed and internal state is cleaned. + OutputStream partitionOutput = new CloseShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { - mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean innerThrewException = true; + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); try { - partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); if (compressionCodec != null) { partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); - innerThrewException = false; + ByteStreams.copy(partitionInputStream, partitionOutput); } finally { - Closeables.close(partitionInputStream, innerThrewException); + partitionInputStream.close(); } } } - mergedFileOutputStream.flush(); - mergedFileOutputStream.close(); - partitionLengths[partition] = (outputFile.length() - initialFileLength); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { diff --git a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java index 9307eb93a5b20..dff4f5df68784 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/TaskSorting.java @@ -19,7 +19,9 @@ import org.apache.spark.util.EnumUtil; +import java.util.Collections; import java.util.HashSet; +import java.util.Locale; import java.util.Set; public enum TaskSorting { @@ -30,13 +32,11 @@ public enum TaskSorting { private final Set alternateNames; TaskSorting(String... 
names) { alternateNames = new HashSet<>(); - for (String n: names) { - alternateNames.add(n); - } + Collections.addAll(alternateNames, names); } public static TaskSorting fromString(String str) { - String lower = str.toLowerCase(); + String lower = str.toLowerCase(Locale.ROOT); for (TaskSorting t: values()) { if (t.alternateNames.contains(lower)) { return t; diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d2fcdea4f2cee..4bef21b6b4e4d 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -170,6 +170,8 @@ public final class BytesToBytesMap extends MemoryConsumer { private long peakMemoryUsedBytes = 0L; + private final int initialCapacity; + private final BlockManager blockManager; private final SerializerManager serializerManager; private volatile MapIterator destructiveIterator = null; @@ -202,6 +204,7 @@ public BytesToBytesMap( throw new IllegalArgumentException("Page size " + pageSizeBytes + " cannot exceed " + TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES); } + this.initialCapacity = initialCapacity; allocate(initialCapacity); } @@ -695,7 +698,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff if (numKeys == MAX_CAPACITY // The map could be reused from last spill (because of no enough memory to grow), // then we don't try to grow again if hit the `growthThreshold`. - || !canGrowArray && numKeys > growthThreshold) { + || !canGrowArray && numKeys >= growthThreshold) { return false; } @@ -739,7 +742,7 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff longArray.set(pos * 2 + 1, keyHashcode); isDefined = true; - if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) { + if (numKeys >= growthThreshold && longArray.size() < MAX_CAPACITY) { try { growAndRehash(); } catch (OutOfMemoryError oom) { @@ -902,12 +905,13 @@ public LongArray getArray() { public void reset() { numKeys = 0; numValues = 0; - longArray.zeroOut(); - + freeArray(longArray); while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); } + allocate(initialCapacity); + canGrowArray = true; currentPage = null; pageCursor = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java index 404361734a55b..3dd318471008b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RadixSort.java @@ -17,6 +17,8 @@ package org.apache.spark.util.collection.unsafe.sort; +import com.google.common.primitives.Ints; + import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -40,14 +42,14 @@ public class RadixSort { * of always copying the data back to position zero for efficiency. 
*/ public static int sort( - LongArray array, int numRecords, int startByteIndex, int endByteIndex, + LongArray array, long numRecords, int startByteIndex, int endByteIndex, boolean desc, boolean signed) { assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0"; assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 2 <= array.size(); - int inIndex = 0; - int outIndex = numRecords; + long inIndex = 0; + long outIndex = numRecords; if (numRecords > 0) { long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex); for (int i = startByteIndex; i <= endByteIndex; i++) { @@ -55,13 +57,13 @@ public static int sort( sortAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -78,14 +80,14 @@ public static int sort( * @param signed whether this is a signed (two's complement) sort (only applies to last byte). */ private static void sortAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed); Object baseObject = array.getBaseObject(); - long baseOffset = array.getBaseOffset() + inIndex * 8; - long maxOffset = baseOffset + numRecords * 8; + long baseOffset = array.getBaseOffset() + inIndex * 8L; + long maxOffset = baseOffset + numRecords * 8L; for (long offset = baseOffset; offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); int bucket = (int)((value >>> (byteIdx * 8)) & 0xff); @@ -106,13 +108,13 @@ private static void sortAtByte( * significant byte. If the byte does not need sorting the array will be null. */ private static long[][] getCounts( - LongArray array, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; // Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting. // If all the byte values at a particular index are the same we don't need to count it. long bitwiseMax = 0; long bitwiseMin = -1L; - long maxOffset = array.getBaseOffset() + numRecords * 8; + long maxOffset = array.getBaseOffset() + numRecords * 8L; Object baseObject = array.getBaseObject(); for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) { long value = Platform.getLong(baseObject, offset); @@ -146,18 +148,18 @@ private static long[][] getCounts( * @return the input counts array. */ private static long[] transformCountsToOffsets( - long[] counts, int numRecords, long outputOffset, int bytesPerRecord, + long[] counts, long numRecords, long outputOffset, long bytesPerRecord, boolean desc, boolean signed) { assert counts.length == 256; int start = signed ? 128 : 0; // output the negative records first (values 129-255). 
if (desc) { - int pos = numRecords; + long pos = numRecords; for (int i = start; i < start + 256; i++) { pos -= counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; } } else { - int pos = 0; + long pos = 0; for (int i = start; i < start + 256; i++) { long tmp = counts[i & 0xff]; counts[i & 0xff] = outputOffset + pos * bytesPerRecord; @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets( */ public static int sortKeyPrefixArray( LongArray array, - int startIndex, - int numRecords, + long startIndex, + long numRecords, int startByteIndex, int endByteIndex, boolean desc, @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray( assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7"; assert endByteIndex > startByteIndex; assert numRecords * 4 <= array.size(); - int inIndex = startIndex; - int outIndex = startIndex + numRecords * 2; + long inIndex = startIndex; + long outIndex = startIndex + numRecords * 2L; if (numRecords > 0) { long[][] counts = getKeyPrefixArrayCounts( array, startIndex, numRecords, startByteIndex, endByteIndex); @@ -196,13 +198,13 @@ public static int sortKeyPrefixArray( sortKeyPrefixArrayAtByte( array, numRecords, counts[i], i, inIndex, outIndex, desc, signed && i == endByteIndex); - int tmp = inIndex; + long tmp = inIndex; inIndex = outIndex; outIndex = tmp; } } } - return inIndex; + return Ints.checkedCast(inIndex); } /** @@ -210,7 +212,7 @@ public static int sortKeyPrefixArray( * getCounts with some added parameters but that seems to hurt in benchmarks. */ private static long[][] getKeyPrefixArrayCounts( - LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) { + LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) { long[][] counts = new long[8][]; long bitwiseMax = 0; long bitwiseMin = -1L; @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts( * Specialization of sortAtByte() for key-prefix arrays. */ private static void sortKeyPrefixArrayAtByte( - LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex, + LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex, boolean desc, boolean signed) { assert counts.length == 256; long[] offsets = transformCountsToOffsets( - counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed); + counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed); Object baseObject = array.getBaseObject(); long baseOffset = array.getBaseOffset() + inIndex * 8L; long maxOffset = baseOffset + numRecords * 16L; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index dcae4a34c4b0b..f312fa2b2ddd7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -37,7 +37,6 @@ import org.apache.spark.unsafe.UnsafeAlignedOffset; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; /** @@ -162,14 +161,9 @@ private UnsafeExternalSorter( // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. 
This is necessary to avoid memory leaks in when the downstream operator // does not fully consume the sorter's output (e.g. sort followed by limit). - taskContext.addTaskCompletionListener( - new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - cleanupResources(); - } - } - ); + taskContext.addTaskCompletionListener(context -> { + cleanupResources(); + }); } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 2a71e68adafad..c14c12664f5ab 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -22,6 +22,7 @@ import org.apache.avro.reflect.Nullable; +import org.apache.spark.TaskContext; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -84,7 +85,7 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final PrefixComparators.RadixSortSupport radixSortSupport; /** - * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at + * Within this buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. * * Only part of the array will be used to store the pointers, the rest part is preserved as @@ -253,6 +254,7 @@ public final class SortedIterator extends UnsafeSorterIterator implements Clonea private long keyPrefix; private int recordLength; private long currentPageNumber; + private final TaskContext taskContext = TaskContext.get(); private SortedIterator(int numRecords, int offset) { this.numRecords = numRecords; @@ -283,6 +285,14 @@ public boolean hasNext() { @Override public void loadNext() { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. 
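
The comment above explains why the interruption check is inlined into `loadNext()` instead of wrapping the iterator; the hunk below adds exactly that check. The same pattern, reduced to a skeleton around a hypothetical iterator (not Spark's sorter classes):

```java
import org.apache.spark.TaskContext;

// Skeleton of the inlined cooperative-cancellation check used by the sorter iterators below.
abstract class CancellableRecordIterator {
  // Cache the TaskContext once; it is null when running outside a task, e.g. in local unit tests.
  private final TaskContext taskContext = TaskContext.get();

  /** Advances to the next record, bailing out promptly if the task has been marked as killed. */
  final void loadNext() {
    if (taskContext != null) {
      // Throws a task-killed exception if this task was interrupted, unwinding the task thread.
      taskContext.killTaskIfInterrupted();
    }
    doLoadNext();
  }

  protected abstract void doLoadNext();
}
```
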
+ if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } // This pointer points to a 4-byte record length, followed by the record's bytes final long recordPointer = array.get(offset + position); currentPageNumber = TaskMemoryManager.decodePageNumber(recordPointer); @@ -322,7 +332,7 @@ public UnsafeSorterIterator getSortedIterator() { if (sortComparator != null) { if (this.radixSortSupport != null) { offset = RadixSort.sortKeyPrefixArray( - array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7, + array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7, radixSortSupport.sortDescending(), radixSortSupport.sortSigned()); } else { MemoryBlock unused = new MemoryBlock( diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index 430bf677edbdf..d9f84d10e9051 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -25,7 +25,7 @@ * Supports sorting an array of (record pointer, key prefix) pairs. * Used in {@link UnsafeInMemorySorter}. *

- * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at + * Within each long[] buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ public final class UnsafeSortDataFormat diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 01aed95878cf6..cf4dfde86ca91 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -27,22 +27,18 @@ final class UnsafeSorterSpillMerger { private final PriorityQueue priorityQueue; UnsafeSorterSpillMerger( - final RecordComparator recordComparator, - final PrefixComparator prefixComparator, - final int numSpills) { - final Comparator comparator = new Comparator() { - - @Override - public int compare(UnsafeSorterIterator left, UnsafeSorterIterator right) { - final int prefixComparisonResult = - prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); - if (prefixComparisonResult == 0) { - return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); - } else { - return prefixComparisonResult; - } + RecordComparator recordComparator, + PrefixComparator prefixComparator, + int numSpills) { + Comparator comparator = (left, right) -> { + int prefixComparisonResult = + prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); + if (prefixComparisonResult == 0) { + return recordComparator.compare( + left.getBaseObject(), left.getBaseOffset(), + right.getBaseObject(), right.getBaseOffset()); + } else { + return prefixComparisonResult; } }; priorityQueue = new PriorityQueue<>(numSpills, comparator); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index a658e5eb47b78..9521ab86a12d5 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -23,6 +23,7 @@ import com.google.common.io.Closeables; import org.apache.spark.SparkEnv; +import org.apache.spark.TaskContext; import org.apache.spark.io.NioBufferedFileInputStream; import org.apache.spark.serializer.SerializerManager; import org.apache.spark.storage.BlockId; @@ -51,6 +52,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen private byte[] arr = new byte[1024 * 1024]; private Object baseObject = arr; private final long baseOffset = Platform.BYTE_ARRAY_OFFSET; + private final TaskContext taskContext = TaskContext.get(); public UnsafeSorterSpillReader( SerializerManager serializerManager, @@ -94,6 +96,14 @@ public boolean hasNext() { @Override public void loadNext() throws IOException { + // Kill the task in case it has been marked as killed. This logic is from + // InterruptibleIterator, but we inline it here instead of wrapping the iterator in order + // to avoid performance overhead. 
This check is added here in `loadNext()` instead of in + // `hasNext()` because it's technically possible for the caller to be relying on + // `getNumRecords()` instead of `hasNext()` to know when to stop. + if (taskContext != null) { + taskContext.killTaskIfInterrupted(); + } recordLength = din.readInt(); keyPrefix = din.readLong(); if (recordLength > arr.length) { diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 89a7963a86d98..277010015072a 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -36,3 +36,7 @@ log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR + +# Parquet related logging +log4j.logger.org.apache.parquet.CorruptStatistics=ERROR +log4j.logger.parquet.CorruptStatistics=ERROR diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html index 64ea719141f4b..5c91304e49fd7 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage-template.html @@ -24,7 +24,15 @@

Summary

RDD Blocks Storage Memory + title="Memory used / total available memory for storage of data like RDD partitions cached in memory.">Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores @@ -45,6 +53,11 @@

Summary

title="Bytes and records written to disk in order to be read by a shuffle in a future stage."> Shuffle Write + + + Blacklisted + @@ -68,6 +81,14 @@

Executors

Storage Memory + + + On Heap Storage Memory + + + Off Heap Storage Memory Disk Used Cores Active Tasks diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 1df67337ea031..6643a8f361cdc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -182,7 +182,7 @@ $(document).ready(function () { executorsSummary = $("#active-executors"); getStandAloneppId(function (appId) { - + var endPoint = createRESTEndPoint(appId); $.getJSON(endPoint, function (response, status, jqXHR) { var summary = []; @@ -190,6 +190,10 @@ $(document).ready(function () { var allRDDBlocks = 0; var allMemoryUsed = 0; var allMaxMemory = 0; + var allOnHeapMemoryUsed = 0; + var allOnHeapMaxMemory = 0; + var allOffHeapMemoryUsed = 0; + var allOffHeapMaxMemory = 0; var allDiskUsed = 0; var allTotalCores = 0; var allMaxTasks = 0; @@ -202,11 +206,16 @@ $(document).ready(function () { var allTotalInputBytes = 0; var allTotalShuffleRead = 0; var allTotalShuffleWrite = 0; - + var allTotalBlacklisted = 0; + var activeExecCnt = 0; var activeRDDBlocks = 0; var activeMemoryUsed = 0; var activeMaxMemory = 0; + var activeOnHeapMemoryUsed = 0; + var activeOnHeapMaxMemory = 0; + var activeOffHeapMemoryUsed = 0; + var activeOffHeapMaxMemory = 0; var activeDiskUsed = 0; var activeTotalCores = 0; var activeMaxTasks = 0; @@ -219,11 +228,16 @@ $(document).ready(function () { var activeTotalInputBytes = 0; var activeTotalShuffleRead = 0; var activeTotalShuffleWrite = 0; - + var activeTotalBlacklisted = 0; + var deadExecCnt = 0; var deadRDDBlocks = 0; var deadMemoryUsed = 0; var deadMaxMemory = 0; + var deadOnHeapMemoryUsed = 0; + var deadOnHeapMaxMemory = 0; + var deadOffHeapMemoryUsed = 0; + var deadOffHeapMaxMemory = 0; var deadDiskUsed = 0; var deadTotalCores = 0; var deadMaxTasks = 0; @@ -236,12 +250,28 @@ $(document).ready(function () { var deadTotalInputBytes = 0; var deadTotalShuffleRead = 0; var deadTotalShuffleWrite = 0; - + var deadTotalBlacklisted = 0; + + response.forEach(function (exec) { + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? exec.memoryMetrics : memoryMetrics; + }); + response.forEach(function (exec) { allExecCnt += 1; allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -254,11 +284,16 @@ $(document).ready(function () { allTotalInputBytes += exec.totalInputBytes; allTotalShuffleRead += exec.totalShuffleRead; allTotalShuffleWrite += exec.totalShuffleWrite; + allTotalBlacklisted += exec.isBlacklisted ? 
1 : 0; if (exec.isActive) { activeExecCnt += 1; activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -271,11 +306,16 @@ $(document).ready(function () { activeTotalInputBytes += exec.totalInputBytes; activeTotalShuffleRead += exec.totalShuffleRead; activeTotalShuffleWrite += exec.totalShuffleWrite; + activeTotalBlacklisted += exec.isBlacklisted ? 1 : 0; } else { deadExecCnt += 1; deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -288,14 +328,19 @@ $(document).ready(function () { deadTotalInputBytes += exec.totalInputBytes; deadTotalShuffleRead += exec.totalShuffleRead; deadTotalShuffleWrite += exec.totalShuffleWrite; + deadTotalBlacklisted += exec.isBlacklisted ? 1 : 0; } }); - + var totalSummary = { "execCnt": ( "Total(" + allExecCnt + ")"), "allRDDBlocks": allRDDBlocks, "allMemoryUsed": allMemoryUsed, "allMaxMemory": allMaxMemory, + "allOnHeapMemoryUsed": allOnHeapMemoryUsed, + "allOnHeapMaxMemory": allOnHeapMaxMemory, + "allOffHeapMemoryUsed": allOffHeapMemoryUsed, + "allOffHeapMaxMemory": allOffHeapMaxMemory, "allDiskUsed": allDiskUsed, "allTotalCores": allTotalCores, "allMaxTasks": allMaxTasks, @@ -307,13 +352,18 @@ $(document).ready(function () { "allTotalGCTime": allTotalGCTime, "allTotalInputBytes": allTotalInputBytes, "allTotalShuffleRead": allTotalShuffleRead, - "allTotalShuffleWrite": allTotalShuffleWrite + "allTotalShuffleWrite": allTotalShuffleWrite, + "allTotalBlacklisted": allTotalBlacklisted }; var activeSummary = { "execCnt": ( "Active(" + activeExecCnt + ")"), "allRDDBlocks": activeRDDBlocks, "allMemoryUsed": activeMemoryUsed, "allMaxMemory": activeMaxMemory, + "allOnHeapMemoryUsed": activeOnHeapMemoryUsed, + "allOnHeapMaxMemory": activeOnHeapMaxMemory, + "allOffHeapMemoryUsed": activeOffHeapMemoryUsed, + "allOffHeapMaxMemory": activeOffHeapMaxMemory, "allDiskUsed": activeDiskUsed, "allTotalCores": activeTotalCores, "allMaxTasks": activeMaxTasks, @@ -325,13 +375,18 @@ $(document).ready(function () { "allTotalGCTime": activeTotalGCTime, "allTotalInputBytes": activeTotalInputBytes, "allTotalShuffleRead": activeTotalShuffleRead, - "allTotalShuffleWrite": activeTotalShuffleWrite + "allTotalShuffleWrite": activeTotalShuffleWrite, + "allTotalBlacklisted": activeTotalBlacklisted }; var deadSummary = { "execCnt": ( "Dead(" + deadExecCnt + ")" ), "allRDDBlocks": deadRDDBlocks, "allMemoryUsed": deadMemoryUsed, "allMaxMemory": deadMaxMemory, + "allOnHeapMemoryUsed": deadOnHeapMemoryUsed, + "allOnHeapMaxMemory": deadOnHeapMaxMemory, + "allOffHeapMemoryUsed": deadOffHeapMemoryUsed, + "allOffHeapMaxMemory": deadOffHeapMaxMemory, "allDiskUsed": deadDiskUsed, "allTotalCores": deadTotalCores, "allMaxTasks": 
deadMaxTasks, @@ -343,12 +398,13 @@ $(document).ready(function () { "allTotalGCTime": deadTotalGCTime, "allTotalInputBytes": deadTotalInputBytes, "allTotalShuffleRead": deadTotalShuffleRead, - "allTotalShuffleWrite": deadTotalShuffleWrite + "allTotalShuffleWrite": deadTotalShuffleWrite, + "allTotalBlacklisted": deadTotalBlacklisted }; - + var data = {executors: response, "execSummary": [activeSummary, deadSummary, totalSummary]}; $.get(createTemplateURI(appId), function (template) { - + executorsSummary.append(Mustache.render($(template).filter("#executors-summary-template").html(), data)); var selector = "#active-executors-table"; var conf = { @@ -360,11 +416,44 @@ $(document).ready(function () { } }, {data: 'hostPort'}, - {data: 'isActive', render: formatStatus}, + {data: 'isActive', render: function (data, type, row) { + if (type !== 'display') return data; + if (row.isBlacklisted) return "Blacklisted"; + else return formatStatus (data, type); + } + }, {data: 'rddBlocks'}, { data: function (row, type) { - return type === 'display' ? (formatBytes(row.memoryUsed, type) + ' / ' + formatBytes(row.maxMemory, type)) : row.memoryUsed; + if (type !== 'display') + return row.memoryUsed; + else + return (formatBytes(row.memoryUsed, type) + ' / ' + + formatBytes(row.maxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.memoryMetrics.usedOnHeapStorageMemory; + else + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.memoryMetrics.usedOffHeapStorageMemory; + else + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'diskUsed', render: formatBytes}, @@ -403,27 +492,20 @@ $(document).ready(function () { {data: 'totalInputBytes', render: formatBytes}, {data: 'totalShuffleRead', render: formatBytes}, {data: 'totalShuffleWrite', render: formatBytes}, - {data: 'executorLogs', render: formatLogsCells}, + {name: 'executorLogsCol', data: 'executorLogs', render: formatLogsCells}, { + name: 'threadDumpCol', data: 'id', render: function (data, type) { return type === 'display' ? ("Thread Dump" ) : data; } } ], - "columnDefs": [ - { - "targets": [ 15 ], - "visible": logsExist(response) - }, - { - "targets": [ 16 ], - "visible": getThreadDumpEnabled() - } - ], "order": [[0, "asc"]] }; - $(selector).DataTable(conf); + var dt = $(selector).DataTable(conf); + dt.column('executorLogsCol:name').visible(logsExist(response)); + dt.column('threadDumpCol:name').visible(getThreadDumpEnabled()); $('#active-executors [data-toggle="tooltip"]').tooltip(); var sumSelector = "#summary-execs-table"; @@ -439,7 +521,35 @@ $(document).ready(function () { {data: 'allRDDBlocks'}, { data: function (row, type) { - return type === 'display' ? 
(formatBytes(row.allMemoryUsed, type) + ' / ' + formatBytes(row.allMaxMemory, type)) : row.allMemoryUsed; + if (type !== 'display') + return row.allMemoryUsed + else + return (formatBytes(row.allMemoryUsed, type) + ' / ' + + formatBytes(row.allMaxMemory, type)); + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOnHeapMemoryUsed; + else + return (formatBytes(row.allOnHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOnHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('on_heap_memory') + } + }, + { + data: function (row, type) { + if (type !== 'display') + return row.allOffHeapMemoryUsed; + else + return (formatBytes(row.allOffHeapMemoryUsed, type) + ' / ' + + formatBytes(row.allOffHeapMaxMemory, type)); + }, + "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { + $(nTd).addClass('off_heap_memory') } }, {data: 'allDiskUsed', render: formatBytes}, @@ -477,7 +587,8 @@ $(document).ready(function () { }, {data: 'allTotalInputBytes', render: formatBytes}, {data: 'allTotalShuffleRead', render: formatBytes}, - {data: 'allTotalShuffleWrite', render: formatBytes} + {data: 'allTotalShuffleWrite', render: formatBytes}, + {data: 'allTotalBlacklisted'} ], "paging": false, "searching": false, diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js new file mode 100644 index 0000000000000..55d540d8317a0 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-common.js @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +$(document).ready(function() { + if ($('#last-updated').length) { + var lastUpdatedMillis = Number($('#last-updated').text()); + var updatedDate = new Date(lastUpdatedMillis); + $('#last-updated').text(updatedDate.toLocaleDateString()+", "+updatedDate.toLocaleTimeString()) + } +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 1fd6ef4a71253..6ba3b092dc658 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -68,16 +68,16 @@ {{#applications}} - {{id}} + {{id}} {{name}} {{#attempts}} - {{attemptId}} + {{attemptId}} {{startTime}} {{endTime}} {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 2a32e18672a22..1f89306403cd5 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -78,6 +78,12 @@ jQuery.extend( jQuery.fn.dataTableExt.oSort, { } } ); +jQuery.extend( jQuery.fn.dataTableExt.ofnSearch, { + "appid-numeric": function ( a ) { + return a.replace(/[\r\n]/g, " ").replace(/<.*?>/g, ""); + } +} ); + $(document).ajaxStop($.unblockUI); $(document).ajaxStart(function(){ $.blockUI({ message: '
<h3>Loading history summary...</h3>
'}); @@ -114,12 +120,19 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } } - var data = {"applications": array} + var data = { + "uiroot": uiRoot, + "applications": array + } + $.get("static/historypage-template.html", function(template) { historySummary.append(Mustache.render($(template).filter("#history-summary-template").html(),data)); var selector = "#history-summary-table"; @@ -135,6 +148,9 @@ $(document).ready(function() { {name: 'eighth'}, {name: 'ninth'}, ], + "columnDefs": [ + {"searchable": false, "targets": [5]} + ], "autoWidth": false, "order": [[ 4, "desc" ]] }; diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js index 1782b4f209c09..b5c43e5788bc3 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/log-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -51,13 +51,26 @@ function noNewAlert() { window.setTimeout(function () {alert.css("display", "none");}, 4000); } + +function getRESTEndPoint() { + // If the worker is served from the master through a proxy (see doc on spark.ui.reverseProxy), + // we need to retain the leading ../proxy// part of the URL when making REST requests. + // Similar logic is contained in executorspage.js function createRESTEndPoint. + var words = document.baseURI.split('/'); + var ind = words.indexOf("proxy"); + if (ind > 0) { + return words.slice(0, ind + 2).join('/') + "/log"; + } + return "/log" +} + function loadMore() { var offset = Math.max(startByte - byteLength, 0); var moreByteLength = Math.min(byteLength, startByte); $.ajax({ type: "GET", - url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + url: getRESTEndPoint() + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, success: function (data) { var oldHeight = $(".log-content")[0].scrollHeight; var newlineIndex = data.indexOf('\n'); @@ -83,14 +96,14 @@ function loadMore() { function loadNew() { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=0", + url: getRESTEndPoint() + baseParams + "&byteLength=0", success: function (data) { var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); var newDataLen = dataInfo[2] - totalLogLength; if (newDataLen != 0) { $.ajax({ type: "GET", - url: "/log" + baseParams + "&byteLength=" + newDataLen, + url: getRESTEndPoint() + baseParams + "&byteLength=" + newDataLen, success: function (data) { var newlineIndex = data.indexOf('\n'); var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js index ff241470f32df..9960d5c34d1fc 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js +++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js @@ -207,8 +207,8 @@ sorttable = { hasInputs = (typeof node.getElementsByTagName == 'function') && node.getElementsByTagName('input').length; - - if (node.getAttribute("sorttable_customkey") != null) { + + if (node.nodeType == 1 && 
node.getAttribute("sorttable_customkey") != null) { return node.getAttribute("sorttable_customkey"); } else if (typeof node.textContent != 'undefined' && !hasInputs) { diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 1b0d4692d9cd0..75b959fdeb59a 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -35,7 +35,7 @@ * primitives (e.g. take, any SQL query). * * In the visualization, an RDD is expressed as a node, and its dependencies - * as directed edges (from parent to child). operation scopes, stages, and + * as directed edges (from parent to child). Operation scopes, stages, and * jobs are expressed as clusters that may contain one or many nodes. These * clusters may be nested inside of each other in the scenarios described * above. @@ -173,6 +173,7 @@ function renderDagViz(forJob) { }); resizeSvg(svg); + interpretLineBreak(svg); } /* Render the RDD DAG visualization on the stage page. */ @@ -362,6 +363,27 @@ function resizeSvg(svg) { .attr("height", height); } +/* + * Helper function to interpret line break for tag 'tspan'. + * For tag 'tspan', line break '/n' is display in UI as raw for both stage page and job page, + * here this function is to enable line break. + */ +function interpretLineBreak(svg) { + var allTSpan = svg.selectAll("tspan").each(function() { + node = d3.select(this); + var original = node[0][0].innerHTML; + if (original.indexOf("\\n") != -1) { + var arr = original.split("\\n"); + var newNode = this.cloneNode(this); + + node[0][0].innerHTML = arr[0]; + newNode.innerHTML = arr[1]; + + this.parentNode.appendChild(newNode); + } + }); +} + /* * (Job page only) Helper function to draw edges that cross stage boundaries. * We need to do this manually because we render each stage separately in dagre-d3. @@ -470,15 +492,23 @@ function connectRDDs(fromRDDId, toRDDId, edgesContainer, svgContainer) { edgesContainer.append("path").datum(points).attr("d", line); } +/* + * Replace `/n` with `
` + */ +function replaceLineBreak(str) { + return str.replace("\\n", "
"); +} + /* (Job page only) Helper function to add tooltips for RDDs. */ function addTooltipsForRDDs(svgContainer) { svgContainer.selectAll("g.node").each(function() { var node = d3.select(this); - var tooltipText = node.attr("name"); + var tooltipText = replaceLineBreak(node.attr("name")); if (tooltipText) { node.select("circle") .attr("data-toggle", "tooltip") .attr("data-placement", "bottom") + .attr("data-html", "true") // to interpret line break, tooltipText is showing title .attr("title", tooltipText); } // Link tooltips for all nodes that belong to the same RDD diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js index 14b06bfe860ed..0315ebf5c48a9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/table.js +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -36,7 +36,7 @@ function toggleThreadStackTrace(threadId, forceAdd) { if (stackTrace.length == 0) { var stackTraceText = $('#' + threadId + "_td_stacktrace").html() var threadCell = $("#thread_" + threadId + "_tr") - threadCell.after("
" +
+        threadCell.after("
" +
             stackTraceText +  "
") } else { if (!forceAdd) { @@ -73,6 +73,7 @@ function onMouseOverAndOut(threadId) { $("#" + threadId + "_td_id").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_name").toggleClass("threaddump-td-mouseover"); $("#" + threadId + "_td_state").toggleClass("threaddump-td-mouseover"); + $("#" + threadId + "_td_locking").toggleClass("threaddump-td-mouseover"); } function onSearchStringChange() { diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b157f3e0a407d..935d9b1aec615 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -205,7 +205,8 @@ span.additional-metric-title { /* Hide all additional metrics by default. This is done here rather than using JavaScript to * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ .scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote, -.serialization_time, .getting_result_time, .peak_execution_memory { +.serialization_time, .getting_result_time, .peak_execution_memory, +.on_heap_memory, .off_heap_memory { display: none; } @@ -246,4 +247,8 @@ a.expandbutton { text-align: center; margin: 0; padding: 4px 0; +} + +.table-cell-width-limited td { + max-width: 600px; } \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index e37307aa1f705..0fa1fcf25f8b9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -15,6 +15,12 @@ * limitations under the License. */ +var uiRoot = ""; + +function setUIRoot(val) { + uiRoot = val; +} + function collapseTablePageLoad(name, table){ if (window.localStorage.getItem(name) == "true") { // Set it to false so that the click function can revert it diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala b/core/src/main/scala/org/apache/spark/Accumulator.scala index 9d1f1d59dbce1..9d5fbefc824ad 100644 --- a/core/src/main/scala/org/apache/spark/Accumulator.scala +++ b/core/src/main/scala/org/apache/spark/Accumulator.scala @@ -24,9 +24,8 @@ package org.apache.spark * They can be used to implement counters (as in MapReduce) or sums. Spark natively supports * accumulators of numeric value types, and programmers can add support for new types. * - * An accumulator is created from an initial value `v` by calling - * [[SparkContext#accumulator SparkContext.accumulator]]. - * Tasks running on the cluster can then add to it using the [[Accumulable#+= +=]] operator. + * An accumulator is created from an initial value `v` by calling `SparkContext.accumulator`. + * Tasks running on the cluster can then add to it using the `+=` operator. * However, they cannot read its value. Only the driver program can read the accumulator's value, * using its [[#value]] method. 
* diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 5678d790e9e76..4d884dec07916 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,7 +18,8 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} +import java.util.Collections +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, ScheduledExecutorService, TimeUnit} import scala.collection.JavaConverters._ @@ -58,7 +59,12 @@ private class CleanupTaskWeakReference( */ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { - private val referenceBuffer = new ConcurrentLinkedQueue[CleanupTaskWeakReference]() + /** + * A buffer to ensure that `CleanupTaskWeakReference`s are not garbage collected as long as they + * have not been handled by the reference queue. + */ + private val referenceBuffer = + Collections.newSetFromMap[CleanupTaskWeakReference](new ConcurrentHashMap) private val referenceQueue = new ReferenceQueue[AnyRef] @@ -139,7 +145,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { periodicGCService.shutdown() } - /** Register a RDD for cleanup when it is garbage collected. */ + /** Register an RDD for cleanup when it is garbage collected. */ def registerRDDForCleanup(rdd: RDD[_]): Unit = { registerForCleanup(rdd, CleanRDD(rdd.id)) } @@ -176,10 +182,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging { .map(_.asInstanceOf[CleanupTaskWeakReference]) // Synchronize here to avoid being interrupted on stop() synchronized { - reference.map(_.task).foreach { task => - logDebug("Got cleaning task " + task) - referenceBuffer.remove(reference.get) - task match { + reference.foreach { ref => + logDebug("Got cleaning task " + ref.task) + referenceBuffer.remove(ref) + ref.task match { case CleanRDD(rddId) => doCleanupRDD(rddId, blocking = blockOnCleanupTasks) case CleanShuffle(shuffleId) => diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 5d47f624ac8a3..9112d93a86b2a 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -54,9 +54,27 @@ private[spark] trait ExecutorAllocationClient { /** * Request that the cluster manager kill the specified executors. + * + * When asking the executor to be replaced, the executor loss is considered a failure, and + * killed tasks that are running on the executor will count towards the failure limits. If no + * replacement is being requested, then the tasks will not count towards the limit. + * + * @param executorIds identifiers of executors to kill + * @param replace whether to replace the killed executors with new ones, default false + * @param force whether to force kill busy executors, default false * @return the ids of the executors acknowledged by the cluster manager to be removed. */ - def killExecutors(executorIds: Seq[String]): Seq[String] + def killExecutors( + executorIds: Seq[String], + replace: Boolean = false, + force: Boolean = false): Seq[String] + + /** + * Request that the cluster manager kill every executor on the specified host. + * + * @return whether the request is acknowledged by the cluster manager. 
+ */ + def killExecutorsOnHost(host: String): Boolean /** * Request that the cluster manager kill the specified executor. diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 1366251d0618f..fcc72ff49276d 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 @@ -439,7 +439,7 @@ private[spark] class ExecutorAllocationManager( executorsRemoved } else { logWarning(s"Unable to reach the cluster manager to kill executor/s " + - "executorIdsToBeRemoved.mkString(\",\") or no executor eligible to kill!") + s"${executorIdsToBeRemoved.mkString(",")} or no executor eligible to kill!") Seq.empty[String] } } diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 5c262bcbddf76..7f2c0068174b5 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -33,11 +33,8 @@ class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which // introduces an expensive read fence. - if (context.isInterrupted) { - throw new TaskKilledException - } else { - delegate.hasNext - } + context.killTaskIfInterrupted() + delegate.hasNext } def next(): T = delegate.next() diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 7f8f0f513134f..4ef6656222455 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -99,7 +99,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected def askTracker[T: ClassTag](message: Any): T = { try { - trackerEndpoint.askWithRetry[T](message) + trackerEndpoint.askSync[T](message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -317,12 +317,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf, pool } - // Make sure that that we aren't going to exceed the max RPC message size by making sure + // Make sure that we aren't going to exceed the max RPC message size by making sure // we use broadcast to send large map output statuses. if (minSizeForBroadcast > maxRpcMessageSize) { val msg = s"spark.shuffle.mapOutput.minSizeForBroadcast ($minSizeForBroadcast bytes) must " + s"be <= spark.rpc.message.maxSize ($maxRpcMessageSize bytes) to prevent sending an rpc " + - "message that is to large." + "message that is too large." 
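A hedged illustration of the sanity check above in `MapOutputTrackerMaster`: the map-output broadcast threshold must not exceed the maximum RPC message size. The values below are arbitrary.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.rpc.message.maxSize", "128")                     // MiB
  .set("spark.shuffle.mapOutput.minSizeForBroadcast", "512k")  // must stay <= spark.rpc.message.maxSize
```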
logError(msg) throw new IllegalArgumentException(msg) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 93dfbc0e6ed65..f83f5278e8b8f 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -101,7 +101,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. * - * Note that the actual number of partitions created by the RangePartitioner might not be the same + * @note The actual number of partitions created by the RangePartitioner might not be the same * as the `partitions` parameter, in the case where the number of sampled records is less than * the value of `partitions`. */ diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index be19179b00a49..29163e7f30546 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -34,6 +34,8 @@ import org.apache.spark.internal.Logging * * @param enabled enables or disables SSL; if it is set to false, the rest of the * settings are disregarded + * @param port the port where to bind the SSL server; if not defined, it will be + * based on the non-SSL port for the same service. * @param keyStore a path to the key-store file * @param keyStorePassword a password to access the key-store file * @param keyPassword a password to access the private key in the key-store @@ -47,6 +49,7 @@ import org.apache.spark.internal.Logging */ private[spark] case class SSLOptions( enabled: Boolean = false, + port: Option[Int] = None, keyStore: Option[File] = None, keyStorePassword: Option[String] = None, keyPassword: Option[String] = None, @@ -150,8 +153,8 @@ private[spark] object SSLOptions extends Logging { * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers * * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. + * + * Oracle blog page. * * You can optionally specify the default configuration. If you do, for each setting which is * missing in SparkConf, the corresponding setting is used from the default configuration. 
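A hedged sketch of the namespaced SSL settings that `SSLOptions.parse` reads, including the new per-namespace `port` override introduced in this change. The `spark.ssl.ui` namespace, paths, and passwords are assumptions for illustration only.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.ssl.enabled", "true")
  .set("spark.ssl.keyStore", "/path/to/keystore.jks")
  .set("spark.ssl.keyStorePassword", "changeit")
  .set("spark.ssl.ui.port", "4443")   // per-namespace override; otherwise derived from the non-SSL port
```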
@@ -164,6 +167,11 @@ private[spark] object SSLOptions extends Logging { def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) + val port = conf.getOption(s"$ns.port").map(_.toInt) + port.foreach { p => + require(p >= 0, "Port number must be a non-negative value.") + } + val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_)) .orElse(defaults.flatMap(_.keyStore)) @@ -198,6 +206,7 @@ private[spark] object SSLOptions extends Logging { new SSLOptions( enabled, + port, keyStore, keyStorePassword, keyPassword, diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 199365ad925a3..2480e56b72ccf 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -21,19 +21,16 @@ import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} import java.security.{KeyStore, SecureRandom} import java.security.cert.X509Certificate -import javax.crypto.KeyGenerator import javax.net.ssl._ import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text -import org.apache.hadoop.security.Credentials import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.sasl.SecretKeyHolder -import org.apache.spark.security.CryptoStreamUtils._ import org.apache.spark.util.Utils /** @@ -185,7 +182,9 @@ import org.apache.spark.util.Utils * setting `spark.ssl.useNodeLocalConf` to `true`. */ -private[spark] class SecurityManager(sparkConf: SparkConf) +private[spark] class SecurityManager( + sparkConf: SparkConf, + val ioEncryptionKey: Option[Array[Byte]] = None) extends Logging with SecretKeyHolder { import SecurityManager._ @@ -193,7 +192,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) // allow all users/groups to have view/modify permissions private val WILDCARD_ACL = "*" - private val authOn = sparkConf.getBoolean(SecurityManager.SPARK_AUTH_CONF, false) + private val authOn = sparkConf.get(NETWORK_AUTH_ENABLED) // keep spark.ui.acls.enable for backwards compatibility with 1.0 private var aclsOn = sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false)) @@ -415,6 +414,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) logInfo("Changing acls enabled to: " + aclsOn) } + def getIOEncryptionKey(): Option[Array[Byte]] = ioEncryptionKey + /** * Generates or looks up the secret key. * @@ -516,11 +517,11 @@ private[spark] class SecurityManager(sparkConf: SparkConf) def isAuthenticationEnabled(): Boolean = authOn /** - * Checks whether SASL encryption should be enabled. - * @return Whether to enable SASL encryption when connecting to services that support it. + * Checks whether network encryption should be enabled. + * @return Whether to enable encryption when connecting to services that support it. 
*/ - def isSaslEncryptionEnabled(): Boolean = { - sparkConf.getBoolean("spark.authenticate.enableSaslEncryption", false) + def isEncryptionEnabled(): Boolean = { + sparkConf.get(NETWORK_ENCRYPTION_ENABLED) || sparkConf.get(SASL_ENCRYPTION_ENABLED) } /** @@ -559,19 +560,4 @@ private[spark] object SecurityManager { // key used to store the spark secret in the Hadoop UGI val SECRET_LOOKUP_KEY = "sparkCookie" - /** - * Setup the cryptographic key used by IO encryption in credentials. The key is generated using - * [[KeyGenerator]]. The algorithm and key length is specified by the [[SparkConf]]. - */ - def initIOEncryptionKey(conf: SparkConf, credentials: Credentials): Unit = { - if (credentials.getSecretKey(SPARK_IO_TOKEN) == null) { - val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) - val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) - val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) - keyGen.init(keyLen) - - val ioKey = keyGen.generateKey() - credentials.addSecretKey(SPARK_IO_TOKEN, ioKey.getEncoded) - } - } } diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index c9c342df82c97..956724b14bba3 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -42,10 +42,10 @@ import org.apache.spark.util.Utils * All setter methods in this class support chaining. For example, you can write * `new SparkConf().setMaster("local").setAppName("My app")`. * - * Note that once a SparkConf object is passed to Spark, it is cloned and can no longer be modified - * by the user. Spark does not support modifying the configuration at runtime. - * * @param loadDefaults whether to also load values from Java system properties + * + * @note Once a SparkConf object is passed to Spark, it is cloned and can no longer be modified + * by the user. Spark does not support modifying the configuration at runtime. */ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Serializable { @@ -262,7 +262,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as seconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then seconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the time parameter is not set */ def getTimeAsSeconds(key: String): Long = { Utils.timeStringAsSeconds(get(key)) @@ -279,7 +279,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a time parameter as milliseconds; throws a NoSuchElementException if it's not set. If no * suffix is provided then milliseconds are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the time parameter is not set */ def getTimeAsMs(key: String): Long = { Utils.timeStringAsMs(get(key)) @@ -296,7 +296,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as bytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then bytes are assumed. 
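The switch to the combined `isEncryptionEnabled` check above, together with the validation added later in `SparkConf` (authentication must be on whenever encryption is on), corresponds to configuration roughly like the following; the key names are taken from Spark 2.x documentation and should be treated as assumptions.

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.authenticate", "true")                          // required whenever encryption is enabled
  .set("spark.network.crypto.enabled", "true")                // AES-based RPC encryption
  // .set("spark.authenticate.enableSaslEncryption", "true")  // older SASL-based alternative
```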
- * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the size parameter is not set */ def getSizeAsBytes(key: String): Long = { Utils.byteStringAsBytes(get(key)) @@ -320,7 +320,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Kibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Kibibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the size parameter is not set */ def getSizeAsKb(key: String): Long = { Utils.byteStringAsKb(get(key)) @@ -337,7 +337,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Mebibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Mebibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the size parameter is not set */ def getSizeAsMb(key: String): Long = { Utils.byteStringAsMb(get(key)) @@ -354,7 +354,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria /** * Get a size parameter as Gibibytes; throws a NoSuchElementException if it's not set. If no * suffix is provided then Gibibytes are assumed. - * @throws NoSuchElementException + * @throws java.util.NoSuchElementException If the size parameter is not set */ def getSizeAsGb(key: String): Long = { Utils.byteStringAsGb(get(key)) @@ -378,7 +378,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray } - /** Get all parameters that start with `prefix` */ + /** + * Get all parameters that start with `prefix` + */ def getAllWithPrefix(prefix: String): Array[(String, String)] = { getAll.filter { case (k, v) => k.startsWith(prefix) } .map { case (k, v) => (k.substring(prefix.length), v) } @@ -516,71 +518,6 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } } - // Check for legacy configs - sys.env.get("SPARK_JAVA_OPTS").foreach { value => - val warning = - s""" - |SPARK_JAVA_OPTS was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with conf/spark-defaults.conf to set defaults for an application - | - ./spark-submit with --driver-java-options to set -X options for a driver - | - spark.executor.extraJavaOptions to set -X options for executors - | - SPARK_DAEMON_JAVA_OPTS to set java options for standalone daemons (master or worker) - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorOptsKey, driverOptsKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_JAVA_OPTS. Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - sys.env.get("SPARK_CLASSPATH").foreach { value => - val warning = - s""" - |SPARK_CLASSPATH was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --driver-class-path to augment the driver classpath - | - spark.executor.extraClassPath to augment the executor classpath - """.stripMargin - logWarning(warning) - - for (key <- Seq(executorClasspathKey, driverClassPathKey)) { - if (getOption(key).isDefined) { - throw new SparkException(s"Found both $key and SPARK_CLASSPATH. 
Use only the former.") - } else { - logWarning(s"Setting '$key' to '$value' as a work-around.") - set(key, value) - } - } - } - - if (!contains(sparkExecutorInstances)) { - sys.env.get("SPARK_WORKER_INSTANCES").foreach { value => - val warning = - s""" - |SPARK_WORKER_INSTANCES was detected (set to '$value'). - |This is deprecated in Spark 1.0+. - | - |Please instead use: - | - ./spark-submit with --num-executors to specify the number of executors - | - Or set SPARK_EXECUTOR_INSTANCES - | - spark.executor.instances to configure the number of instances in the spark config. - """.stripMargin - logWarning(warning) - - set("spark.executor.instances", value) - } - } - if (contains("spark.master") && get("spark.master").startsWith("yarn-")) { val warning = s"spark.master ${get("spark.master")} is deprecated in Spark 2.0+, please " + "instead use \"yarn\" with specified deploy mode." @@ -605,6 +542,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria "\"client\".") } } + + val encryptionEnabled = get(NETWORK_ENCRYPTION_ENABLED) || get(SASL_ENCRYPTION_ENABLED) + require(!encryptionEnabled || get(NETWORK_AUTH_ENABLED), + s"${NETWORK_AUTH_ENABLED.key} must be enabled when enabling encryption.") } /** @@ -638,7 +579,9 @@ private[spark] object SparkConf extends Logging { "are no longer accepted. To specify the equivalent now, one may use '64k'."), DeprecatedConfig("spark.rpc", "2.0", "Not used any more."), DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0", - "Please use the new blacklisting options, spark.blacklist.*") + "Please use the new blacklisting options, spark.blacklist.*"), + DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"), + DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) @@ -697,8 +640,10 @@ private[spark] object SparkConf extends Logging { "spark.rpc.message.maxSize" -> Seq( AlternateConfig("spark.akka.frameSize", "1.6")), "spark.yarn.jars" -> Seq( - AlternateConfig("spark.yarn.jar", "2.0")) - ) + AlternateConfig("spark.yarn.jar", "2.0")), + "spark.yarn.access.hadoopFileSystems" -> Seq( + AlternateConfig("spark.yarn.access.namenodes", "2.2")) + ) /** * A view of `configsWithAlternatives` that makes it more efficient to look up deprecated @@ -722,6 +667,7 @@ private[spark] object SparkConf extends Logging { (name.startsWith("spark.auth") && name != SecurityManager.SPARK_AUTH_SECRET_CONF) || name.startsWith("spark.ssl") || name.startsWith("spark.rpc") || + name.startsWith("spark.network") || isSparkPortConf(name) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4694790c72cd8..7dbceb9c5c1a3 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor -import java.net.{MalformedURLException, URI} +import java.net.URI import java.util.{Arrays, Locale, Properties, ServiceLoader, UUID} import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -183,6 +183,8 @@ class SparkContext(config: SparkConf) extends Logging { // log out Spark Version in Spark driver log logInfo(s"Running Spark version $SPARK_VERSION") + warnDeprecatedVersions() + /* 
------------------------------------------------------------------------------------- * | Private variables. These variables keep the internal state of the context, and are | | not accessible by the outside world. They're mutable since we want to initialize all | @@ -279,7 +281,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. * - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration: Configuration = _hadoopConfiguration @@ -346,13 +348,20 @@ class SparkContext(config: SparkConf) extends Logging { value } + private def warnDeprecatedVersions(): Unit = { + val javaVersion = System.getProperty("java.version").split("[+.\\-]+", 3) + if (scala.util.Properties.releaseVersion.exists(_.startsWith("2.10"))) { + logWarning("Support for Scala 2.10 is deprecated as of Spark 2.1.0") + } + } + /** Control our logLevel. This overrides any user-defined log settings. * @param logLevel The desired log level as a string. * Valid log levels include: ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN */ def setLogLevel(logLevel: String) { // let's allow lowercase or mixed case too - val upperCased = logLevel.toUpperCase(Locale.ENGLISH) + val upperCased = logLevel.toUpperCase(Locale.ROOT) require(SparkContext.VALID_LOG_LEVELS.contains(upperCased), s"Supplied level $logLevel did not match one of:" + s" ${SparkContext.VALID_LOG_LEVELS.mkString(",")}") @@ -370,6 +379,9 @@ class SparkContext(config: SparkConf) extends Logging { throw new SparkException("An application name must be set in your configuration") } + // log out spark.app.name in the Spark driver logs + logInfo(s"Submitted application: $appName") + // System property spark.yarn.app.id must be set if user code ran by AM on a YARN cluster if (master == "yarn" && deployMode == "cluster" && !_conf.contains("spark.yarn.app.id")) { throw new SparkException("Detected yarn cluster mode, but isn't running on a cluster. " + @@ -410,10 +422,6 @@ class SparkContext(config: SparkConf) extends Logging { } if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true") - if (_conf.get(IO_ENCRYPTION_ENABLED) && !SparkHadoopUtil.get.isYarnMode()) { - throw new SparkException("IO encryption is only supported in YARN mode, please disable it " + - s"by setting ${IO_ENCRYPTION_ENABLED.key} to false") - } // "_jobProgressListener" should be set up before creating SparkEnv because when creating // "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them. @@ -597,7 +605,7 @@ class SparkContext(config: SparkConf) extends Logging { Some(Utils.getThreadDump()) } else { val endpointRef = env.blockManager.master.getExecutorEndpointRef(executorId).get - Some(endpointRef.askWithRetry[Array[ThreadStackTrace]](TriggerThreadDump)) + Some(endpointRef.askSync[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -633,7 +641,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. 
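A small sketch of the thread-local property API referenced above, assuming an existing `SparkContext` named `sc`; `spark.scheduler.pool` is one property commonly set this way.

```scala
sc.setLocalProperty("spark.scheduler.pool", "production")
assert(sc.getLocalProperty("spark.scheduler.pool") == "production")
sc.setLocalProperty("spark.scheduler.pool", null)   // null removes the property for this thread
```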
*/ def getLocalProperty(key: String): String = Option(localProperties.get).map(_.getProperty(key)).orNull @@ -651,7 +659,7 @@ class SparkContext(config: SparkConf) extends Logging { * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all + * The application can also use `org.apache.spark.SparkContext.cancelJobGroup` to cancel all * running jobs in this group. For example, * {{{ * // In the main thread: @@ -662,10 +670,10 @@ class SparkContext(config: SparkConf) extends Logging { * sc.cancelJobGroup("some_job_to_cancel") * }}} * - * If interruptOnCancel is set to true for the job group, then job cancellation will result - * in Thread.interrupt() being called on the job's executor threads. This is useful to help ensure - * that the tasks are actually stopped in a timely manner, but is off by default due to HDFS-1208, - * where HDFS may respond to Thread.interrupt() by marking nodes as dead. + * @param interruptOnCancel If true, then job cancellation will result in `Thread.interrupt()` + * being called on the job's executor threads. This is useful to help ensure that the tasks + * are actually stopped in a timely manner, but is off by default due to HDFS-1208, where HDFS + * may respond to Thread.interrupt() by marking nodes as dead. */ def setJobGroup(groupId: String, description: String, interruptOnCancel: Boolean = false) { setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description) @@ -688,7 +696,7 @@ class SparkContext(config: SparkConf) extends Logging { * Execute a block of code in a scope such that all new RDDs created in this body will * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}. * - * Note: Return statements are NOT allowed in the given body. + * @note Return statements are NOT allowed in the given body. */ private[spark] def withScope[U](body: => U): U = RDDOperationScope.withScope[U](this)(body) @@ -701,6 +709,9 @@ class SparkContext(config: SparkConf) extends Logging { * modified collection. Pass a copy of the argument to avoid this. * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions. + * @param seq Scala collection to distribute + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed collection */ def parallelize[T: ClassTag]( seq: Seq[T], @@ -718,8 +729,8 @@ class SparkContext(config: SparkConf) extends Logging { * @param start the start value. * @param end the end value. * @param step the incremental step - * @param numSlices the partition number of the new RDD. - * @return + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed range */ def range( start: Long, @@ -784,6 +795,9 @@ class SparkContext(config: SparkConf) extends Logging { /** Distribute a local Scala collection to form an RDD. * * This method is identical to `parallelize`. 
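A minimal sketch of the `parallelize`/`makeRDD` parameters documented above (`seq`, `numSlices`, and the location-preference variant), assuming an existing `SparkContext` named `sc`; host names are illustrative.

```scala
val rdd = sc.parallelize(1 to 100, numSlices = 4)
println(rdd.getNumPartitions)                        // 4

val located = sc.makeRDD[String](Seq(
  ("a", Seq("host1")),                               // one partition per element,
  ("b", Seq("host2"))))                              // placed near the preferred hosts when possible
```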
+ * @param seq Scala collection to distribute + * @param numSlices number of partitions to divide the collection into + * @return RDD representing distributed collection */ def makeRDD[T: ClassTag]( seq: Seq[T], @@ -795,6 +809,8 @@ class SparkContext(config: SparkConf) extends Logging { * Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. + * @param seq list of tuples of data and location preferences (hostnames of Spark nodes) + * @return RDD representing data partitioned according to location preferences */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() @@ -805,6 +821,9 @@ class SparkContext(config: SparkConf) extends Logging { /** * Read a text file from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI, and return it as an RDD of Strings. + * @param path path to the text file on a supported file system + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of lines of the text file */ def textFile( path: String, @@ -840,10 +859,13 @@ class SparkContext(config: SparkConf) extends Logging { * @note Small files are preferred, large file is also allowable, but may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * @note Partitioning is determined by data locality. This may result in too few partitions + * by default. * * @param path Directory to the input data files, the path can be comma separated paths as the * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. + * @return RDD representing tuples of file path and the corresponding file content */ def wholeTextFiles( path: String, @@ -889,10 +911,13 @@ class SparkContext(config: SparkConf) extends Logging { * @note Small files are preferred; very large files may cause bad performance. * @note On some filesystems, `.../path/*` can be a more efficient way to read all files * in a directory rather than `.../path/` or `.../path` + * @note Partitioning is determined by data locality. This may result in too few partitions + * by default. * * @param path Directory to the input data files, the path can be comma separated paths as the * list of inputs. * @param minPartitions A suggestion value of the minimal splitting number for input data. + * @return RDD representing tuples of file path and corresponding file content */ def binaryFiles( path: String, @@ -915,7 +940,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Load data from a flat binary file, assuming the length of each record is constant. * - * '''Note:''' We ensure that the byte array for each record in the resulting RDD + * @note We ensure that the byte array for each record in the resulting RDD * has the provided record length. 
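A sketch of `binaryRecords`, whose fixed record-length guarantee is restated above; the path and record length are illustrative.

```scala
// Every element of the resulting RDD is a byte array of exactly `recordLength` bytes.
val records = sc.binaryRecords("hdfs:///data/fixed-width.bin", recordLength = 128)
records.take(1).foreach(r => assert(r.length == 128))
```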
* * @param path Directory to the input data files, the path can be comma separated paths as the @@ -936,12 +961,11 @@ class SparkContext(config: SparkConf) extends Logging { classOf[LongWritable], classOf[BytesWritable], conf = conf) - val data = br.map { case (k, v) => - val bytes = v.getBytes + br.map { case (k, v) => + val bytes = v.copyBytes() assert(bytes.length == recordLength, "Byte array does not have correct length") bytes } - data } /** @@ -953,12 +977,13 @@ class SparkContext(config: SparkConf) extends Logging { * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make * sure you won't modify the conf. A safe approach is always creating a new conf for * a new RDD. - * @param inputFormatClass Class of the InputFormat - * @param keyClass Class of the keys - * @param valueClass Class of the values + * @param inputFormatClass storage format of the data to be read + * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter + * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter * @param minPartitions Minimum number of Hadoop Splits to generate. + * @return RDD of tuples of key and corresponding value * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -983,11 +1008,18 @@ class SparkContext(config: SparkConf) extends Logging { /** Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param inputFormatClass storage format of the data to be read + * @param keyClass `Class` of the key associated with the `inputFormatClass` parameter + * @param valueClass `Class` of the value associated with the `inputFormatClass` parameter + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V]( path: String, @@ -1022,11 +1054,15 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minPartitions) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. 
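The Writable-reuse caveat repeated throughout these docs in practice: a sketch using the old `mapred` API that copies values out of the reused Writables before caching; the path is illustrative.

```scala
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapred.TextInputFormat

val lines = sc
  .hadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///logs")
  .map { case (offset, text) => (offset.get, text.toString) }   // copy before caching or aggregating
  .cache()
```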
+ * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V, F <: InputFormat[K, V]] (path: String, minPartitions: Int) @@ -1046,18 +1082,37 @@ class SparkContext(config: SparkConf) extends Logging { * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) * }}} * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths as + * a list of inputs + * @return RDD of tuples of key and corresponding value */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { hadoopFile[K, V, F](path, defaultMinPartitions) } - /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ + /** + * Smarter version of `newApiHadoopFile` that uses class tags to figure out the classes of keys, + * values and the `org.apache.hadoop.mapreduce.InputFormat` (new MapReduce API) so that user + * don't need to pass them directly. Instead, callers can just write, for example: + * ``` + * val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path) + * ``` + * + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @return RDD of tuples of key and corresponding value + */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]] (path: String) (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = withScope { @@ -1072,11 +1127,18 @@ class SparkContext(config: SparkConf) extends Logging { * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. 
+ * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param fClass storage format of the data to be read + * @param kClass `Class` of the key associated with the `fClass` parameter + * @param vClass `Class` of the value associated with the `fClass` parameter + * @param conf Hadoop configuration + * @return RDD of tuples of key and corresponding value */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( path: String, @@ -1108,11 +1170,11 @@ class SparkContext(config: SparkConf) extends Logging { * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make * sure you won't modify the conf. A safe approach is always creating a new conf for * a new RDD. - * @param fClass Class of the InputFormat - * @param kClass Class of the keys - * @param vClass Class of the values + * @param fClass storage format of the data to be read + * @param kClass `Class` of the key associated with the `fClass` parameter + * @param vClass `Class` of the value associated with the `fClass` parameter * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first @@ -1138,11 +1200,17 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param keyClass `Class` of the key associated with `SequenceFileInputFormat` + * @param valueClass `Class` of the value associated with `SequenceFileInputFormat` + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V](path: String, keyClass: Class[K], @@ -1157,11 +1225,16 @@ class SparkContext(config: SparkConf) extends Logging { /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. 
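A sketch of the two `sequenceFile` styles documented above: explicit Writable classes, or Scala types resolved through the implicit `WritableConverter`s; the paths are illustrative.

```scala
import org.apache.hadoop.io.{IntWritable, Text}

val withWritables  = sc.sequenceFile("hdfs:///seq", classOf[Text], classOf[IntWritable])
val withConverters = sc.sequenceFile[String, Int]("hdfs:///seq", minPartitions = 8)
```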
+ * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param keyClass `Class` of the key associated with `SequenceFileInputFormat` + * @param valueClass `Class` of the value associated with `SequenceFileInputFormat` + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V]( path: String, @@ -1187,11 +1260,15 @@ class SparkContext(config: SparkConf) extends Logging { * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle * operation will create many references to the same object. * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first * copy them using a `map` function. + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD of tuples of key and corresponding value */ def sequenceFile[K, V] (path: String, minPartitions: Int = defaultMinPartitions) @@ -1216,6 +1293,11 @@ class SparkContext(config: SparkConf) extends Logging { * be pretty slow if you use the default serializer (Java serialization), * though the nice thing about it is that there's very little effort required to save arbitrary * objects. + * + * @param path directory to the input data files, the path can be comma separated paths + * as a list of inputs + * @param minPartitions suggested minimum number of partitions for the resulting RDD + * @return RDD representing deserialized data from the file(s) */ def objectFile[T: ClassTag]( path: String, @@ -1268,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1297,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1318,19 +1400,21 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Register the given accumulator. Note that accumulators must be registered before use, or it - * will throw exception. + * Register the given accumulator. + * + * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _]): Unit = { acc.register(this) } /** - * Register the given accumulator with given name. Note that accumulators must be registered - * before use, or it will throw exception. + * Register the given accumulator with given name. + * + * @note Accumulators must be registered before use, or it will throw exception. 
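A sketch of the registration requirement noted above, using the `AccumulatorV2` API and assuming an existing `SparkContext` named `sc`.

```scala
import org.apache.spark.util.LongAccumulator

val badRecords = new LongAccumulator
sc.register(badRecords, "bad records")                   // must be registered before use
sc.parallelize(1 to 10).foreach(_ => badRecords.add(1))
println(badRecords.value)
```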
*/ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** @@ -1370,7 +1454,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. */ def collectionAccumulator[T]: CollectionAccumulator[T] = { @@ -1380,7 +1464,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Create and register a [[CollectionAccumulator]], which starts with empty list and accumulates + * Create and register a `CollectionAccumulator`, which starts with empty list and accumulates * inputs by adding them into the list. */ def collectionAccumulator[T](name: String): CollectionAccumulator[T] = { @@ -1393,6 +1477,9 @@ class SparkContext(config: SparkConf) extends Logging { * Broadcast a read-only variable to the cluster, returning a * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. + * + * @param value value to broadcast to the Spark nodes + * @return `Broadcast` object, a read-only variable cached on each machine */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { assertNotStopped() @@ -1407,8 +1494,9 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * + * @param path can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, * use `SparkFiles.get(fileName)` to find its download location. */ def addFile(path: String): Unit = { @@ -1422,12 +1510,12 @@ class SparkContext(config: SparkConf) extends Logging { /** * Add a file to be downloaded with this Spark job on every node. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, - * use `SparkFiles.get(fileName)` to find its download location. * - * A directory can be given if the recursive option is set to true. Currently directories are only - * supported for Hadoop-supported filesystems. + * @param path can be either a local file, a file in HDFS (or other Hadoop-supported + * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs, + * use `SparkFiles.get(fileName)` to find its download location. + * @param recursive if true, a directory can be given in `path`. Currently directories are + * only supported for Hadoop-supported filesystems. */ def addFile(path: String, recursive: Boolean): Unit = { val uri = new Path(path).toUri @@ -1479,6 +1567,15 @@ class SparkContext(config: SparkConf) extends Logging { listenerBus.addListener(listener) } + /** + * :: DeveloperApi :: + * Deregister the listener from Spark's listener bus. 
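The broadcast and `addFile` behaviour documented above can be sketched as follows; `sc` is an assumed existing `SparkContext` and the file path is hypothetical.

```scala
import org.apache.spark.SparkFiles

// Assumes an existing SparkContext `sc`.
val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))       // sent to each executor only once

val resolved = sc.parallelize(Seq("a", "b", "c"))
  .map(id => lookup.value.getOrElse(id, -1))
  .collect()

// Ship a side file to every node (hypothetical path); tasks read the local copy.
sc.addFile("/tmp/lookup.csv")
sc.parallelize(1 to 2).foreach { _ =>
  val localPath = SparkFiles.get("lookup.csv")            // download location on the worker
  println(localPath)
}
```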
+ */ + @DeveloperApi + def removeSparkListener(listener: SparkListenerInterface): Unit = { + listenerBus.removeListener(listener) + } + private[spark] def getExecutorIds(): Seq[String] = { schedulerBackend match { case b: CoarseGrainedSchedulerBackend => @@ -1538,7 +1635,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executors. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executors it kills * through this method with new ones, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1560,7 +1657,7 @@ class SparkContext(config: SparkConf) extends Logging { * :: DeveloperApi :: * Request that the cluster manager kill the specified executor. * - * Note: This is an indication to the cluster manager that the application wishes to adjust + * @note This is an indication to the cluster manager that the application wishes to adjust * its resource usage downwards. If the application wishes to replace the executor it kills * through this method with a new one, it should follow up explicitly with a call to * {{SparkContext#requestExecutors}}. @@ -1578,7 +1675,7 @@ class SparkContext(config: SparkConf) extends Logging { * this request. This assumes the cluster manager will automatically and eventually * fulfill all missing application resource requests. * - * Note: The replace is by no means guaranteed; another application on the same cluster + * @note The replace is by no means guaranteed; another application on the same cluster * can steal the window of opportunity and acquire this application's resources in the * mean time. * @@ -1627,7 +1724,8 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap @@ -1636,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus @@ -1697,9 +1796,9 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future. - * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported - * filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. + * Adds a JAR dependency for all tasks to be executed on this `SparkContext` in the future. + * @param path can be either a local file, a file in HDFS (or other Hadoop-supported filesystems), + * an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node. 
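A sketch pairing the existing `addSparkListener` with the `removeSparkListener` added in this patch; `sc` is an assumed existing `SparkContext` and the listener body is illustrative only.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

// Assumes an existing SparkContext `sc`.
val listener = new SparkListener {
  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
    println(s"job ${jobEnd.jobId} ended with ${jobEnd.jobResult}")
}

sc.addSparkListener(listener)     // start receiving scheduler events
// ... run jobs ...
sc.removeSparkListener(listener)  // deregister once the events are no longer needed
```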
*/ def addJar(path: String) { if (path == null) { @@ -1716,29 +1815,20 @@ class SparkContext(config: SparkConf) extends Logging { key = uri.getScheme match { // A JAR file which exists only on the driver node case null | "file" => - if (master == "yarn" && deployMode == "cluster") { - // In order for this to work in yarn cluster mode the user must specify the - // --addJars option to the client to upload the file into the distributed cache - // of the AM to make it show up in the current working directory. - val fileName = new Path(uri.getPath).getName() - try { - env.rpcEnv.fileServer.addJar(new File(fileName)) - } catch { - case e: Exception => - // For now just log an error but allow to go through so spark examples work. - // The spark examples don't really need the jar distributed since its also - // the app jar. - logError("Error adding jar (" + e + "), was the --addJars option used?") - null + try { + val file = new File(uri.getPath) + if (!file.exists()) { + throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found") } - } else { - try { - env.rpcEnv.fileServer.addJar(new File(uri.getPath)) - } catch { - case exc: FileNotFoundException => - logError(s"Jar not found at $path") - null + if (file.isDirectory) { + throw new IllegalArgumentException( + s"Directory ${file.getAbsoluteFile} is not allowed for addJar") } + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) + } catch { + case NonFatal(e) => + logError(s"Failed to add $path to Spark environment", e) + null } // A JAR file which exists locally on every worker node case "local" => @@ -1762,8 +1852,31 @@ class SparkContext(config: SparkConf) extends Logging { */ def listJars(): Seq[String] = addedJars.keySet.toSeq - // Shut down the SparkContext. - def stop() { + /** + * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark + * may wait for some internal threads to finish. It's better to use this method to stop + * SparkContext instead. + */ + private[spark] def stopInNewThread(): Unit = { + new Thread("stop-spark-context") { + setDaemon(true) + + override def run(): Unit = { + try { + SparkContext.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + } + }.start() + } + + /** + * Shut down the SparkContext. + */ + def stop(): Unit = { if (LiveListenerBus.withinListenerThread.value) { throw new SparkException( s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}") @@ -1826,6 +1939,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() @@ -1883,6 +1999,12 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a function on a given set of partitions in an RDD and pass the results to the given * handler function. This is the main entry point for all actions in Spark. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. 
for operations like `first()` + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1905,6 +2027,14 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a function on a given set of partitions in an RDD and return the results as an array. + * The function that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1916,8 +2046,14 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. + * Run a function on a given set of partitions in an RDD and return the results as an array. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1928,7 +2064,13 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Run a job on all partitions in an RDD and return the results in an array. + * Run a job on all partitions in an RDD and return the results in an array. The function + * that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.length) @@ -1936,13 +2078,23 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a job on all partitions in an RDD and return the results in an array. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @return in-memory collection with a result of the job (each collection element will contain + * a result from one partition) */ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.length) } /** - * Run a job on all partitions in an RDD and pass the results to a handler function. + * Run a job on all partitions in an RDD and pass the results to a handler function. The function + * that is run against each partition additionally takes `TaskContext` argument. + * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1954,6 +2106,10 @@ class SparkContext(config: SparkConf) extends Logging { /** * Run a job on all partitions in an RDD and pass the results to a handler function. 
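The `partitions` parameter documented above is easiest to see in a small sketch; `sc` is an assumed existing `SparkContext`.

```scala
// Assumes an existing SparkContext `sc`.
val rdd = sc.parallelize(1 to 100, numSlices = 4)

// Run only on partition 0, the kind of partial job first() issues internally.
val firstPartitionSum: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0))

// Run on all partitions; the result array has one element per partition.
val perPartitionCounts: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.size)
```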
+ * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param resultHandler callback to pass each result to */ def runJob[T, U: ClassTag]( rdd: RDD[T], @@ -1967,6 +2123,13 @@ class SparkContext(config: SparkConf) extends Logging { /** * :: DeveloperApi :: * Run a job that can return approximate results. + * + * @param rdd target RDD to run tasks on + * @param func a function to run on each partition of the RDD + * @param evaluator `ApproximateEvaluator` to receive the partial results + * @param timeout maximum time to wait for the job, in milliseconds + * @return partial result (how partial depends on whether the job was finished before or + * after timeout) */ @DeveloperApi def runApproximateJob[T, U, R]( @@ -1988,6 +2151,13 @@ class SparkContext(config: SparkConf) extends Logging { /** * Submit a job for execution and return a FutureJob holding the result. + * + * @param rdd target RDD to run tasks on + * @param processPartition a function to run on each partition of the RDD + * @param partitions set of partitions to run on; some jobs may not want to compute on all + * partitions of the target RDD, e.g. for operations like `first()` + * @param resultHandler callback to pass each result to + * @param resultFunc function to be executed when the result is ready */ def submitJob[T, U, R]( rdd: RDD[T], @@ -2027,7 +2197,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]] + * Cancel active jobs for the specified group. See `org.apache.spark.SparkContext.setJobGroup` * for more information. */ def cancelJobGroup(groupId: String) { @@ -2045,20 +2215,60 @@ class SparkContext(config: SparkConf) extends Logging { * Cancel a given job if it's scheduled or running. * * @param jobId the job ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @param reason optional reason for cancellation + * @note Throws `InterruptedException` if the cancel message cannot be sent + */ + def cancelJob(jobId: Int, reason: String): Unit = { + dagScheduler.cancelJob(jobId, Option(reason)) + } + + /** + * Cancel a given job if it's scheduled or running. + * + * @param jobId the job ID to cancel + * @note Throws `InterruptedException` if the cancel message cannot be sent */ - def cancelJob(jobId: Int) { - dagScheduler.cancelJob(jobId) + def cancelJob(jobId: Int): Unit = { + dagScheduler.cancelJob(jobId, None) } /** * Cancel a given stage and all jobs associated with it. * * @param stageId the stage ID to cancel - * @throws InterruptedException if the cancel message cannot be sent + * @param reason reason for cancellation + * @note Throws `InterruptedException` if the cancel message cannot be sent */ - def cancelStage(stageId: Int) { - dagScheduler.cancelStage(stageId) + def cancelStage(stageId: Int, reason: String): Unit = { + dagScheduler.cancelStage(stageId, Option(reason)) + } + + /** + * Cancel a given stage and all jobs associated with it. + * + * @param stageId the stage ID to cancel + * @note Throws `InterruptedException` if the cancel message cannot be sent + */ + def cancelStage(stageId: Int): Unit = { + dagScheduler.cancelStage(stageId, None) + } + + /** + * Kill and reschedule the given task attempt. Task ids can be obtained from the Spark UI + * or through SparkListener.onTaskStart. + * + * @param taskId the task ID to kill. This id uniquely identifies the task attempt. 
+ * @param interruptThread whether to interrupt the thread running the task. + * @param reason the reason for killing the task, which should be a short string. If a task + * is killed multiple times with different reasons, only one reason will be reported. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt( + taskId: Long, + interruptThread: Boolean = true, + reason: String = "killed via SparkContext.killTaskAttempt"): Boolean = { + dagScheduler.killTaskAttempt(taskId, interruptThread, reason) } /** @@ -2072,6 +2282,7 @@ class SparkContext(config: SparkConf) extends Logging { * @param checkSerializable whether or not to immediately check f for serializability * @throws SparkException if checkSerializable is set but f is not * serializable + * @return the cleaned closure */ private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { ClosureCleaner.clean(f, checkSerializable) @@ -2079,8 +2290,9 @@ class SparkContext(config: SparkConf) extends Logging { } /** - * Set the directory under which RDDs are going to be checkpointed. The directory must - * be a HDFS path if running on a cluster. + * Set the directory under which RDDs are going to be checkpointed. + * @param directory path to the directory where checkpoint files will be stored + * (must be HDFS path if running in cluster) */ def setCheckpointDir(directory: String) { @@ -2285,8 +2497,10 @@ object SparkContext extends Logging { * singleton object. Because we can only have one active SparkContext per JVM, * this is useful when applications may wish to share a SparkContext. * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. + * @param config `SparkConfig` that will be used for initialisation of the `SparkContext` + * @return current `SparkContext` (or a new one if it wasn't created before the function call) */ def getOrCreate(config: SparkConf): SparkContext = { // Synchronize to ensure that multiple create requests don't trigger an exception @@ -2310,8 +2524,9 @@ object SparkContext extends Logging { * * This method allows not passing a SparkConf (useful if just retrieving). * - * Note: This function cannot be used to create multiple SparkContext instances + * @note This function cannot be used to create multiple SparkContext instances * even if multiple contexts are allowed. + * @return current `SparkContext` (or a new one if wasn't created before the function call) */ def getOrCreate(): SparkContext = { SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { @@ -2322,6 +2537,13 @@ object SparkContext extends Logging { } } + /** Return the current active [[SparkContext]] if any. */ + private[spark] def getActive: Option[SparkContext] = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + Option(activeContext.get()) + } + } + /** * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is * running. Throws an exception if a running context is detected and logs a warning if another @@ -2392,6 +2614,9 @@ object SparkContext extends Logging { /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. 
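A sketch of the cancellation APIs above; `sc` is an assumed existing `SparkContext`, and the job, stage, and task ids are placeholders that would normally come from the Spark UI or a `SparkListener`.

```scala
// Assumes an existing SparkContext `sc`; all ids below are placeholders.
sc.cancelJob(42, "result no longer needed")   // overload with an explicit reason
sc.cancelStage(7)                             // reason defaults to None

// Kill a single task attempt; returns whether the kill was delivered.
val killed: Boolean =
  sc.killTaskAttempt(12345L, interruptThread = true, reason = "straggling task")
```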
+ * + * @param cls class that should be inside of the jar + * @return jar that contains the Class, `None` if not found */ def jarOfClass(cls: Class[_]): Option[String] = { val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class") @@ -2413,6 +2638,9 @@ object SparkContext extends Logging { * Find the JAR that contains the class of a particular object, to make it easy for users * to pass their JARs to SparkContext. In most cases you can call jarOfObject(this) in * your driver program. + * + * @param obj reference to an instance which class should be inside of the jar + * @return jar that contains the class of the instance, `None` if not found */ def jarOfObject(obj: AnyRef): Option[String] = jarOfClass(obj.getClass) @@ -2550,8 +2778,8 @@ object SparkContext extends Logging { val serviceLoaders = ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url)) if (serviceLoaders.size > 1) { - throw new SparkException(s"Multiple Cluster Managers ($serviceLoaders) registered " + - s"for the url $url:") + throw new SparkException( + s"Multiple external cluster managers registered for the url $url: $serviceLoaders") } serviceLoaders.headOption } @@ -2572,11 +2800,12 @@ private object SparkMasterRegex { } /** - * A class encapsulating how to convert some type T to Writable. It stores both the Writable class - * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The getter for the writable class takes a ClassTag[T] in case this is a generic object - * that doesn't know the type of T when it is created. This sounds strange but is necessary to - * support converting subclasses of Writable to themselves (writableWritableConverter). + * A class encapsulating how to convert some type `T` from `Writable`. It stores both the `Writable` + * class corresponding to `T` (e.g. `IntWritable` for `Int`) and a function for doing the + * conversion. + * The getter for the writable class takes a `ClassTag[T]` in case this is a generic object + * that doesn't know the type of `T` when it is created. This sounds strange but is necessary to + * support converting subclasses of `Writable` to themselves (`writableWritableConverter()`). */ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], @@ -2627,9 +2856,10 @@ object WritableConverter { } /** - * A class encapsulating how to convert some type T to Writable. It stores both the Writable class - * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The Writable class will be used in `SequenceFileRDDFunctions`. + * A class encapsulating how to convert some type `T` to `Writable`. It stores both the `Writable` + * class corresponding to `T` (e.g. `IntWritable` for `Int`) and a function for doing the + * conversion. + * The `Writable` class will be used in `SequenceFileRDDFunctions`. 
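A sketch tying together `getOrCreate` and the jar helpers documented above; the application name and master are arbitrary.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object SharedContextSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("shared-context").setMaster("local[*]")

    // Returns the active SparkContext if one exists, otherwise creates one from `conf`.
    val sc = SparkContext.getOrCreate(conf)

    // Find the jar containing this class (None when run from a REPL or tests) and ship it.
    SparkContext.jarOfObject(this).foreach(sc.addJar)

    sc.stop()
  }
}
```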
*/ private[spark] class WritableFactory[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1ffeb129880f9..3196c1ece15eb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket +import java.util.Locale import scala.collection.mutable import scala.util.Properties @@ -36,6 +37,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint +import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage._ @@ -165,15 +167,20 @@ object SparkEnv extends Logging { val bindAddress = conf.get(DRIVER_BIND_ADDRESS) val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) val port = conf.get("spark.driver.port").toInt + val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(conf)) + } else { + None + } create( conf, SparkContext.DRIVER_IDENTIFIER, bindAddress, advertiseAddress, - port, - isDriver = true, - isLocal = isLocal, - numUsableCores = numCores, + Option(port), + isLocal, + numCores, + ioEncryptionKey, listenerBus = listenerBus, mockOutputCommitCoordinator = mockOutputCommitCoordinator ) @@ -187,18 +194,18 @@ object SparkEnv extends Logging { conf: SparkConf, executorId: String, hostname: String, - port: Int, numCores: Int, + ioEncryptionKey: Option[Array[Byte]], isLocal: Boolean): SparkEnv = { val env = create( conf, executorId, hostname, hostname, - port, - isDriver = false, - isLocal = isLocal, - numUsableCores = numCores + None, + isLocal, + numCores, + ioEncryptionKey ) SparkEnv.set(env) env @@ -212,32 +219,35 @@ object SparkEnv extends Logging { executorId: String, bindAddress: String, advertiseAddress: String, - port: Int, - isDriver: Boolean, + port: Option[Int], isLocal: Boolean, numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], listenerBus: LiveListenerBus = null, mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER + // Listener bus is only used on the driver if (isDriver) { assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") } - val securityManager = new SecurityManager(conf) + val securityManager = new SecurityManager(conf, ioEncryptionKey) + ioEncryptionKey.foreach { _ => + if (!securityManager.isEncryptionEnabled()) { + logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") + } + } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf, + val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, securityManager, clientMode = !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - // In the non-driver case, the RPC env's address may be null since it may not be listening - // for incoming connections. 
if (isDriver) { conf.set("spark.driver.port", rpcEnv.address.port.toString) - } else if (rpcEnv.address != null) { - conf.set("spark.executor.port", rpcEnv.address.port.toString) - logInfo(s"Setting spark.executor.port to: ${rpcEnv.address.port.toString}") } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -270,7 +280,7 @@ object SparkEnv extends Logging { "spark.serializer", "org.apache.spark.serializer.JavaSerializer") logDebug(s"Using serializer: ${serializer.getClass}") - val serializerManager = new SerializerManager(serializer, conf) + val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) val closureSerializer = new JavaSerializer(conf) @@ -304,7 +314,8 @@ object SparkEnv extends Logging { "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") - val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) + val shuffleMgrClass = + shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) val useLegacyMemoryManager = conf.getBoolean("spark.memory.useLegacyMode", false) diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 52c4656c271bc..22a553e68439a 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -112,7 +112,7 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { */ def getExecutorInfos: Array[SparkExecutorInfo] = { val executorIdToRunningTasks: Map[String, Int] = - sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors() + sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors sc.getExecutorStorageStatus.map { status => val bmId = status.blockManagerId diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 27abccf5ac2a9..0b87cd503d4fa 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} @@ -104,7 +105,9 @@ abstract class TaskContext extends Serializable { /** * Adds a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -113,7 +116,9 @@ abstract class TaskContext extends Serializable { /** * Adds a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situations - success, failure, or cancellation. 
+ * This will be called in all situations - success, failure, or cancellation. Adding a listener + * to an already completed task will result in that listener being called immediately. + * * An example use is for HadoopRDD to register a callback to close the input stream. * * Exceptions thrown by the listener will result in failure of the task. @@ -125,14 +130,14 @@ abstract class TaskContext extends Serializable { } /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(listener: TaskFailureListener): TaskContext /** - * Adds a listener to be executed on task failure. - * Operations defined here must be idempotent, as `onTaskFailure` can be called multiple times. + * Adds a listener to be executed on task failure. Adding a listener to an already failed task + * will result in that listener being called immediately. */ def addTaskFailureListener(f: (TaskContext, Throwable) => Unit): TaskContext = { addTaskFailureListener(new TaskFailureListener { @@ -164,7 +169,7 @@ abstract class TaskContext extends Serializable { /** * Get a local property set upstream in the driver, or null if it is missing. See also - * [[org.apache.spark.SparkContext.setLocalProperty]]. + * `org.apache.spark.SparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String @@ -174,11 +179,21 @@ abstract class TaskContext extends Serializable { /** * ::DeveloperApi:: * Returns all metrics sources with the given name which are associated with the instance - * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + * which runs the task. For more information see `org.apache.spark.metrics.MetricsSystem`. */ @DeveloperApi def getMetricsSources(sourceName: String): Seq[Source] + /** + * If the task is interrupted, throws TaskKilledException with the reason for the interrupt. + */ + private[spark] def killTaskIfInterrupted(): Unit + + /** + * If the task is interrupted, the reason this task was killed, otherwise None. + */ + private[spark] def getKillReason(): Option[String] + /** * Returns the manager for this task's managed memory. */ @@ -190,4 +205,10 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit + /** + * Record that this task has failed due to a fetch failure from a remote host. This allows + * fetch-failure handling to get triggered by the driver, regardless of intervening user-code. 
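The listener semantics described above (listeners added to an already finished task fire immediately) can be sketched like this; `rdd` is an assumed existing RDD.

```scala
import org.apache.spark.TaskContext

// Assumes an existing RDD `rdd`.
rdd.foreachPartition { iter =>
  val ctx = TaskContext.get()

  // If the task were already complete or failed, these listeners would be
  // invoked immediately rather than silently dropped.
  ctx.addTaskCompletionListener { _ => println(s"partition ${ctx.partitionId()} done") }
  ctx.addTaskFailureListener { (_, error) => println(s"task failed: ${error.getMessage}") }

  iter.foreach(_ => ())   // consume the partition
}
```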
+ */ + private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c904e083911cd..01d8973e1bb06 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.util.Properties +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -26,8 +27,19 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ +/** + * A [[TaskContext]] implementation. + * + * A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes + * sure that updates are always visible across threads. The complete & failed flags and their + * callbacks are protected by locking on the context instance. For instance, this ensures + * that you cannot add a completion listener in one thread while we are completing (and calling + * the completion listeners) in another thread. Other state is immutable, however the exposed + * `TaskMetrics` & `MetricsSystem` objects are not thread safe. + */ private[spark] class TaskContextImpl( val stageId: Int, val partitionId: Int, @@ -47,75 +59,108 @@ private[spark] class TaskContextImpl( /** List of callback functions to execute when the task fails. */ @transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener] - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false + // If defined, the corresponding task has been killed and this option contains the reason. + @volatile private var reasonIfKilled: Option[String] = None // Whether the task has completed. - @volatile private var completed: Boolean = false + private var completed: Boolean = false // Whether the task has failed. - @volatile private var failed: Boolean = false - - override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener + private var failed: Boolean = false + + // Throwable that caused the task to fail + private var failure: Throwable = _ + + // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't + // hide the exception. See SPARK-19276 + @volatile private var _fetchFailedException: Option[FetchFailedException] = None + + @GuardedBy("this") + override def addTaskCompletionListener(listener: TaskCompletionListener) + : this.type = synchronized { + if (completed) { + listener.onTaskCompletion(this) + } else { + onCompleteCallbacks += listener + } this } - override def addTaskFailureListener(listener: TaskFailureListener): this.type = { - onFailureCallbacks += listener + @GuardedBy("this") + override def addTaskFailureListener(listener: TaskFailureListener) + : this.type = synchronized { + if (failed) { + listener.onTaskFailure(this, failure) + } else { + onFailureCallbacks += listener + } this } /** Marks the task as failed and triggers the failure listeners. 
*/ - private[spark] def markTaskFailed(error: Throwable): Unit = { - // failure callbacks should only be called once + @GuardedBy("this") + private[spark] def markTaskFailed(error: Throwable): Unit = synchronized { if (failed) return failed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process failure callbacks in the reverse order of registration - onFailureCallbacks.reverse.foreach { listener => - try { - listener.onTaskFailure(this, error) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskFailureListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs, Option(error)) + failure = error + invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) { + _.onTaskFailure(this, error) } } /** Marks the task as completed and triggers the completion listeners. */ - private[spark] def markTaskCompleted(): Unit = { + @GuardedBy("this") + private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized { + if (completed) return completed = true + invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) { + _.onTaskCompletion(this) + } + } + + private def invokeListeners[T]( + listeners: Seq[T], + name: String, + error: Option[Throwable])( + callback: T => Unit): Unit = { val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => + // Process callbacks in the reverse order of registration + listeners.reverse.foreach { listener => try { - listener.onTaskCompletion(this) + callback(listener) } catch { case e: Throwable => errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) + logError(s"Error in $name", e) } } if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) + throw new TaskCompletionListenerException(errorMsgs, error) } } /** Marks the task for interruption, i.e. cancellation. 
*/ - private[spark] def markInterrupted(): Unit = { - interrupted = true + private[spark] def markInterrupted(reason: String): Unit = { + reasonIfKilled = Some(reason) + } + + private[spark] override def killTaskIfInterrupted(): Unit = { + val reason = reasonIfKilled + if (reason.isDefined) { + throw new TaskKilledException(reason.get) + } + } + + private[spark] override def getKillReason(): Option[String] = { + reasonIfKilled } - override def isCompleted(): Boolean = completed + @GuardedBy("this") + override def isCompleted(): Boolean = synchronized(completed) override def isRunningLocally(): Boolean = false - override def isInterrupted(): Boolean = interrupted + override def isInterrupted(): Boolean = reasonIfKilled.isDefined override def getLocalProperty(key: String): String = localProperties.getProperty(key) @@ -126,4 +171,10 @@ private[spark] class TaskContextImpl( taskMetrics.registerAccumulator(a) } + private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + this._fetchFailedException = Option(fetchFailed) + } + + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 7ca3c103dbf5b..a76283e33fa65 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -65,7 +65,7 @@ sealed trait TaskFailedReason extends TaskEndReason { /** * :: DeveloperApi :: - * A [[org.apache.spark.scheduler.ShuffleMapTask]] that completed successfully earlier, but we + * A `org.apache.spark.scheduler.ShuffleMapTask` that completed successfully earlier, but we * lost the executor before the stage completed. This means Spark needs to reschedule the task * to be re-executed on a different executor. */ @@ -98,7 +98,7 @@ case class FetchFailed( * 4 task failures, instead we immediately go back to the stage which generated the map output, * and regenerate the missing data. (2) we don't count fetch failures for blacklisting, since * presumably its not the fault of the executor where the task ran, but the executor which - * stored the data. This is especially important because we we might rack up a bunch of + * stored the data. This is especially important because we might rack up a bunch of * fetch-failures in rapid succession, on all nodes of the cluster, due to one bad node. */ override def countTowardsTaskFailures: Boolean = false @@ -212,8 +212,8 @@ case object TaskResultLost extends TaskFailedReason { * Task was killed intentionally and needs to be rescheduled. */ @DeveloperApi -case object TaskKilled extends TaskFailedReason { - override def toErrorString: String = "TaskKilled (killed intentionally)" +case class TaskKilled(reason: String) extends TaskFailedReason { + override def toErrorString: String = s"TaskKilled ($reason)" override def countTowardsTaskFailures: Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/TaskKilledException.scala b/core/src/main/scala/org/apache/spark/TaskKilledException.scala index ad487c4efb87a..9dbf0d493be11 100644 --- a/core/src/main/scala/org/apache/spark/TaskKilledException.scala +++ b/core/src/main/scala/org/apache/spark/TaskKilledException.scala @@ -24,4 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * Exception thrown when a task is explicitly killed (i.e., task failure is expected). 
*/ @DeveloperApi -class TaskKilledException extends RuntimeException +class TaskKilledException(val reason: String) extends RuntimeException { + def this() = this("unknown reason") +} diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 871b9d1ad575b..3f912dc191515 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -18,19 +18,23 @@ package org.apache.spark import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} -import java.net.{URI, URL} +import java.net.{HttpURLConnection, URI, URL} import java.nio.charset.StandardCharsets -import java.nio.file.Paths +import java.security.SecureRandom +import java.security.cert.X509Certificate import java.util.Arrays import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.jar.{JarEntry, JarOutputStream} +import javax.net.ssl._ +import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.sys.process.{Process, ProcessLogger} +import scala.util.Try import com.google.common.io.{ByteStreams, Files} -import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ @@ -93,7 +97,10 @@ private[spark] object TestUtils { val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) for (file <- files) { - val jarEntry = new JarEntry(Paths.get(directoryPrefix.getOrElse(""), file.getName).toString) + // The `name` for the argument in `JarEntry` should use / for its separator. This is + // ZIP specification. + val prefix = directoryPrefix.map(d => s"$d/").getOrElse("") + val jarEntry = new JarEntry(prefix + file.getName) jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) @@ -182,11 +189,54 @@ private[spark] object TestUtils { assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } + /** + * Test if a command is available. + */ + def testCommandAvailable(command: String): Boolean = { + val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue()) + attempt.isSuccess && attempt.get == 0 + } + + /** + * Returns the response code from an HTTP(S) URL. + */ + def httpResponseCode( + url: URL, + method: String = "GET", + headers: Seq[(String, String)] = Nil): Int = { + val connection = url.openConnection().asInstanceOf[HttpURLConnection] + connection.setRequestMethod(method) + headers.foreach { case (k, v) => connection.setRequestProperty(k, v) } + + // Disable cert and host name validation for HTTPS tests. 
+ if (connection.isInstanceOf[HttpsURLConnection]) { + val sslCtx = SSLContext.getInstance("SSL") + val trustManager = new X509TrustManager { + override def getAcceptedIssuers(): Array[X509Certificate] = null + override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {} + override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {} + } + val verifier = new HostnameVerifier() { + override def verify(hostname: String, session: SSLSession): Boolean = true + } + sslCtx.init(null, Array(trustManager), new SecureRandom()) + connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory()) + connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier) + } + + try { + connection.connect() + connection.getResponseCode() + } finally { + connection.disconnect() + } + } + } /** - * A [[SparkListener]] that detects whether spills have occurred in Spark jobs. + * A `SparkListener` that detects whether spills have occurred in Spark jobs. */ private class SpillListener extends SparkListener { private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 0026fc9dad517..b71af0d42cdb0 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -45,7 +45,9 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) import JavaDoubleRDD.fromRDD - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) /** @@ -153,7 +155,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) @@ -256,7 +258,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 1c95bc4bfcaaf..9544475ff0428 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -54,7 +54,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). 
+ */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) /** @@ -164,7 +166,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * This method differs from `sampleByKey` in that we make additional passes over the RDD to * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) * over all key values with a 99.99% confidence. When sampling without replacement, we need one * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need @@ -182,7 +184,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return a subset of this RDD sampled by key (via stratified sampling) containing exactly * math.ceil(numItems * samplingRate) for each stratum (group of pairs with the same key). * - * This method differs from [[sampleByKey]] in that we make additional passes over the RDD to + * This method differs from `sampleByKey` in that we make additional passes over the RDD to * create a sample size that's exactly equal to the sum of math.ceil(numItems * samplingRate) * over all key values with a 99.99% confidence. When sampling without replacement, we need one * additional pass over the RDD to guarantee sample size; when sampling with replacement, we need @@ -206,7 +208,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.intersection(other.rdd)) @@ -223,9 +225,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -234,6 +236,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * In addition, users can control the partitioning of the output RDD, the serializer that is use * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple * items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -255,9 +260,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C. Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). 
Users provide three - * functions: + * "combined type" C. + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -265,6 +270,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD. This method automatically * uses map-side aggregation in shuffling the RDD. + * + * @note V and C can be different -- for example, one might group an RDD of type (Int, Int) into + * an RDD of type (Int, List[Int]). */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], @@ -398,8 +406,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = @@ -409,8 +417,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. */ def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = @@ -448,13 +456,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.subtractByKey(other)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W](other: JavaPairRDD[K, W], p: Partitioner): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag fromRDD(rdd.subtractByKey(other, p)) @@ -539,8 +551,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. * - * Note: If you are grouping in order to perform an aggregation (such as a sum or average) over - * each key, using [[JavaPairRDD.reduceByKey]] or [[JavaPairRDD.combineByKey]] + * @note If you are grouping in order to perform an aggregation (such as a sum or average) over + * each key, using `JavaPairRDD.reduceByKey` or `JavaPairRDD.combineByKey` * will provide much better performance. 
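The three combineByKey functions and the groupByKey note above translate to the following sketch, shown against the Scala pair RDD API rather than `JavaPairRDD`; `sc` is an assumed existing `SparkContext`.

```scala
// Assumes an existing SparkContext `sc`.
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

// combineByKey with V = Int and C = List[Int]; V and C differ, as the note explains.
val collected = pairs.combineByKey[List[Int]](
  (v: Int) => List(v),                          // createCombiner: V => C
  (c: List[Int], v: Int) => v :: c,             // mergeValue: (C, V) => C
  (c1: List[Int], c2: List[Int]) => c1 ::: c2)  // mergeCombiners: (C, C) => C

// For plain aggregations, reduceByKey combines map-side and is usually much
// faster than groupByKey followed by a reduce.
val sums = pairs.reduceByKey(_ + _)
```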
*/ def groupByKey(): JavaPairRDD[K, JIterable[V]] = diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 20d6c9341bf7a..41b5cab601c36 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -34,7 +34,9 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) // Common RDD functions - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): JavaRDD[T] = wrapRDD(rdd.cache()) /** @@ -98,24 +100,32 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD with a random seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] = sample(withReplacement, fraction, Utils.random.nextLong) /** - * Return a sampled subset of this RDD. + * Return a sampled subset of this RDD, with a user-supplied seed. * * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given `RDD`. */ def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] = wrapRDD(rdd.sample(withReplacement, fraction, seed)) @@ -153,7 +163,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) @@ -161,7 +171,7 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. 
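The fraction semantics spelled out above are approximate rather than exact; a short sketch, assuming an existing `SparkContext` named `sc`:

```scala
// Assumes an existing SparkContext `sc`.
val rdd = sc.parallelize(1 to 10000)

// Roughly 10% of the elements without replacement; the exact count is NOT guaranteed.
val sampled = rdd.sample(withReplacement = false, fraction = 0.1, seed = 42L)
println(sampled.count())   // close to, but usually not exactly, 1000
```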
*/ def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index a37c52cbaf210..91ae1002abd21 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -47,7 +47,8 @@ private[spark] abstract class AbstractJavaRDDLike[T, This <: JavaRDDLike[T, This /** * Defines operations common to several Java RDD implementations. - * Note that this trait is not intended to be implemented by user code. + * + * @note This trait is not intended to be implemented by user code. */ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This @@ -392,7 +393,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth) /** - * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2. + * `org.apache.spark.api.java.JavaRDDLike.treeReduce` with suggested depth 2. */ def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2) @@ -439,7 +440,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2. + * `org.apache.spark.api.java.JavaRDDLike.treeAggregate` with suggested depth 2. */ def treeAggregate[U]( zeroValue: U, diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4e50c2686dd53..9481156bc93a5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -238,7 +238,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}} * * then `rdd` contains * {{{ @@ -270,7 +272,9 @@ class JavaSparkContext(val sc: SparkContext) * }}} * * Do - * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * {{{ + * JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path") + * }}}, * * then `rdd` contains * {{{ @@ -298,7 +302,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile with given key and value types. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -316,7 +320,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop SequenceFile. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. 
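The `@note` on `sequenceFile` (and the other Hadoop input methods below) warns that Hadoop's RecordReader reuses one Writable instance per record, so caching the raw records would leave many references to the same object. A sketch of the recommended copy-via-`map` before caching; the path and Writable types here are placeholders:

```scala
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.spark.{SparkConf, SparkContext}

object SequenceFileCopyExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("seqfile-sketch").setMaster("local[2]"))

    // Raw records alias the same Text/IntWritable objects, so copy the
    // contents into plain immutable values before caching.
    val counts = sc
      .sequenceFile("hdfs://some-path/input", classOf[Text], classOf[IntWritable])
      .map { case (k, v) => (k.toString, v.get) }
      .cache()

    println(counts.count())
    sc.stop()
  }
}
```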
@@ -366,7 +370,7 @@ class JavaSparkContext(val sc: SparkContext) * @param valueClass Class of the values * @param minPartitions Minimum number of Hadoop Splits to generate. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -396,7 +400,7 @@ class JavaSparkContext(val sc: SparkContext) * @param keyClass Class of the keys * @param valueClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -416,7 +420,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -437,7 +441,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get an RDD for a Hadoop file with an arbitrary InputFormat * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -458,7 +462,7 @@ class JavaSparkContext(val sc: SparkContext) * Get an RDD for a given Hadoop file with an arbitrary new API InputFormat * and extra configuration options to pass to the input format. * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -487,7 +491,7 @@ class JavaSparkContext(val sc: SparkContext) * @param kClass Class of the keys * @param vClass Class of the values * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * @note Because Hadoop's RecordReader class re-uses the same Writable object for each * record, directly caching the returned RDD will create many references to the same object. * If you plan to directly cache Hadoop writable objects, you should first copy them using * a `map` function. @@ -694,7 +698,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. 
* - * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you + * @note As it will be reused in all Hadoop RDDs, it's better not to modify it unless you * plan to set some global configurations for all Hadoop RDDs. */ def hadoopConfiguration(): Configuration = { @@ -749,7 +753,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Get a local property set in this thread, or null if it is missing. See - * [[org.apache.spark.api.java.JavaSparkContext.setLocalProperty]]. + * `org.apache.spark.api.java.JavaSparkContext.setLocalProperty`. */ def getLocalProperty(key: String): String = sc.getLocalProperty(key) @@ -769,7 +773,7 @@ class JavaSparkContext(val sc: SparkContext) * Application programmers can use this method to group all those jobs together and give a * group description. Once set, the Spark web UI will associate such jobs with this group. * - * The application can also use [[org.apache.spark.api.java.JavaSparkContext.cancelJobGroup]] + * The application can also use `org.apache.spark.api.java.JavaSparkContext.cancelJobGroup` * to cancel all running jobs in this group. For example, * {{{ * // In the main thread: @@ -802,7 +806,7 @@ class JavaSparkContext(val sc: SparkContext) /** * Cancel active jobs for the specified group. See - * [[org.apache.spark.api.java.JavaSparkContext.setJobGroup]] for more information. + * `org.apache.spark.api.java.JavaSparkContext.setJobGroup` for more information. */ def cancelJobGroup(groupId: String): Unit = sc.cancelJobGroup(groupId) @@ -811,7 +815,8 @@ class JavaSparkContext(val sc: SparkContext) /** * Returns a Java map of JavaRDDs that have marked themselves as persistent via cache() call. - * Note that this does not necessarily mean the caching or computation was successful. + * + * @note This does not necessarily mean the caching or computation was successful. */ def getPersistentRDDs: JMap[java.lang.Integer, JavaRDD[_]] = { sc.getPersistentRDDs.mapValues(s => JavaRDD.fromRDD(s)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 99ca3c77cced0..6aa290ecd7bb5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} * will provide information for the last `spark.ui.retainedStages` stages and * `spark.ui.retainedJobs` jobs. * - * NOTE: this class's constructor should be considered private and may be subject to change. + * @note This class's constructor should be considered private and may be subject to change. 
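`setJobGroup` and `cancelJobGroup`, referenced above, pair a group id set in the job-submitting thread with a cancellation issued from another thread. A minimal sketch of that pattern (local master, toy job):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object JobGroupExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("job-group-sketch").setMaster("local[2]"))

    // From a separate thread: cancel everything running under the group id.
    val canceller = new Thread {
      override def run(): Unit = { Thread.sleep(2000); sc.cancelJobGroup("some_job_to_cancel") }
    }
    canceller.start()

    // In the main thread: tag the job with a group id, then run it.
    sc.setJobGroup("some_job_to_cancel", "illustrative cancellable job", interruptOnCancel = true)
    try {
      sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
    } catch {
      case e: Exception => println(s"job was cancelled: ${e.getMessage}")
    }
    canceller.join()
    sc.stop()
  }
}
```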
*/ class JavaSparkStatusTracker private[spark] (sc: SparkContext) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0ca91b9bf86c6..fb0405b1a69c6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -215,7 +215,7 @@ private[spark] class PythonRunner( case e: Exception if context.isInterrupted => logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) case e: Exception if env.isStopped => logDebug("Exception thrown after context is stopped", e) @@ -275,6 +275,11 @@ private[spark] class PythonRunner( dataOut.writeInt(partitionIndex) // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) @@ -874,7 +879,7 @@ private[spark] class PythonAccumulatorV2( private val serverPort: Int) extends CollectionAccumulator[Array[Byte]] { - Utils.checkHost(serverHost, "Expected hostname") + Utils.checkHost(serverHost) val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) diff --git a/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala new file mode 100644 index 0000000000000..3432700f11602 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/JVMObjectTracker.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.ConcurrentHashMap + +/** JVM object ID wrapper */ +private[r] case class JVMObjectId(id: String) { + require(id != null, "Object ID cannot be null.") +} + +/** + * Counter that tracks JVM objects returned to R. + * This is useful for referencing these objects in RPC calls. + */ +private[r] class JVMObjectTracker { + + private[this] val objMap = new ConcurrentHashMap[JVMObjectId, Object]() + private[this] val objCounter = new AtomicInteger() + + /** + * Returns the JVM object associated with the input key or None if not found. + */ + final def get(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.get(id)) + } else { + None + } + } + + /** + * Returns the JVM object associated with the input key or throws an exception if not found. 
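The PythonRunner hunk above now ships the stage id, partition id, attempt number, and task attempt id to the Python worker. On the JVM side those values come from `TaskContext`; a small sketch of reading the same four fields inside a task:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object TaskContextExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("task-context-sketch").setMaster("local[2]"))
    sc.parallelize(1 to 4, numSlices = 2).foreachPartition { _ =>
      val ctx = TaskContext.get()
      // The same four fields the Python worker now receives.
      println(s"stage=${ctx.stageId()} partition=${ctx.partitionId()} " +
        s"attempt=${ctx.attemptNumber()} taskAttemptId=${ctx.taskAttemptId()}")
    }
    sc.stop()
  }
}
```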
+ */ + @throws[NoSuchElementException]("if key does not exist.") + final def apply(id: JVMObjectId): Object = { + get(id).getOrElse( + throw new NoSuchElementException(s"$id does not exist.") + ) + } + + /** + * Adds a JVM object to track and returns assigned ID, which is unique within this tracker. + */ + final def addAndGetId(obj: Object): JVMObjectId = { + val id = JVMObjectId(objCounter.getAndIncrement().toString) + objMap.put(id, obj) + id + } + + /** + * Removes and returns a JVM object with the specific ID from the tracker, or None if not found. + */ + final def remove(id: JVMObjectId): Option[Object] = this.synchronized { + if (objMap.containsKey(id)) { + Some(objMap.remove(id)) + } else { + None + } + } + + /** + * Number of JVM objects being tracked. + */ + final def size: Int = objMap.size() + + /** + * Clears the tracker. + */ + final def clear(): Unit = objMap.clear() +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 550746c552d02..2d1152a036449 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -22,7 +22,7 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup} +import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel @@ -42,6 +42,9 @@ private[spark] class RBackend { private[this] var bootstrap: ServerBootstrap = null private[this] var bossGroup: EventLoopGroup = null + /** Tracks JVM objects returned to R for this RBackend instance. 
*/ + private[r] val jvmObjectTracker = new JVMObjectTracker + def init(): Int = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( @@ -94,6 +97,7 @@ private[spark] class RBackend { bootstrap.childGroup().shutdownGracefully() } bootstrap = null + jvmObjectTracker.clear() } } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 9f5afa29d6d22..cfd37ac54ba23 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,7 +20,6 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.TimeUnit -import scala.collection.mutable.HashMap import scala.language.existentials import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} @@ -62,7 +61,7 @@ private[r] class RBackendHandler(server: RBackend) assert(numArgs == 1) writeInt(dos, 0) - writeObject(dos, args(0)) + writeObject(dos, args(0), server.jvmObjectTracker) case "stopBackend" => writeInt(dos, 0) writeType(dos, "void") @@ -72,9 +71,9 @@ private[r] class RBackendHandler(server: RBackend) val t = readObjectType(dis) assert(t == 'c') val objToRemove = readString(dis) - JVMObjectTracker.remove(objToRemove) + server.jvmObjectTracker.remove(JVMObjectId(objToRemove)) writeInt(dos, 0) - writeObject(dos, null) + writeObject(dos, null, server.jvmObjectTracker) } catch { case e: Exception => logError(s"Removing $objId failed", e) @@ -143,12 +142,8 @@ private[r] class RBackendHandler(server: RBackend) val cls = if (isStatic) { Utils.classForName(objId) } else { - JVMObjectTracker.get(objId) match { - case None => throw new IllegalArgumentException("Object not found " + objId) - case Some(o) => - obj = o - o.getClass - } + obj = server.jvmObjectTracker(JVMObjectId(objId)) + obj.getClass } val args = readArgs(numArgs, dis) @@ -173,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend) // Write status bit writeInt(dos, 0) - writeObject(dos, ret.asInstanceOf[AnyRef]) + writeObject(dos, ret.asInstanceOf[AnyRef], server.jvmObjectTracker) } else if (methodName == "") { // methodName should be "" for constructor val ctors = cls.getConstructors @@ -193,7 +188,7 @@ private[r] class RBackendHandler(server: RBackend) val obj = ctors(index.get).newInstance(args : _*) writeInt(dos, 0) - writeObject(dos, obj.asInstanceOf[AnyRef]) + writeObject(dos, obj.asInstanceOf[AnyRef], server.jvmObjectTracker) } else { throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId) } @@ -210,7 +205,7 @@ private[r] class RBackendHandler(server: RBackend) // Read a number of arguments from the data input stream def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = { (0 until numArgs).map { _ => - readObject(dis) + readObject(dis, server.jvmObjectTracker) }.toArray } @@ -286,37 +281,4 @@ private[r] class RBackendHandler(server: RBackend) } } -/** - * Helper singleton that tracks JVM objects returned to R. - * This is useful for referencing these objects in RPC calls. - */ -private[r] object JVMObjectTracker { - - // TODO: This map should be thread-safe if we want to support multiple - // connections at the same time - private[this] val objMap = new HashMap[String, Object] - - // TODO: We support only one connection now, so an integer is fine. - // Investigate using use atomic integer in the future. 
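The new `JVMObjectTracker` replaces the old global singleton with a per-`RBackend` instance that is passed explicitly to the SerDe calls above. A usage sketch of the API as defined in the added file; since the class is `private[r]`, the sketch assumes it is compiled inside that package:

```scala
package org.apache.spark.api.r

object TrackerExample {
  def main(args: Array[String]): Unit = {
    val tracker = new JVMObjectTracker

    // Register an object; the returned id is unique within this tracker.
    val id: JVMObjectId = tracker.addAndGetId(new java.lang.StringBuilder("payload"))

    assert(tracker.get(id).isDefined)   // Option-returning lookup
    val obj = tracker(id)               // apply() throws NoSuchElementException if missing
    println(s"tracking ${tracker.size} object(s), found $obj")

    tracker.remove(id)                  // returns the removed object, if any
    tracker.clear()
  }
}
```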
- private[this] var objCounter: Int = 0 - - def getObject(id: String): Object = { - objMap(id) - } - - def get(id: String): Option[Object] = { - objMap.get(id) - } - - def put(obj: Object): String = { - val objId = objCounter.toString - objCounter = objCounter + 1 - objMap.put(objId, obj) - objId - } - def remove(id: String): Option[Object] = { - objMap.remove(id) - } - -} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index a1a5eb8cf55e8..295355c7bf018 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.r +import java.io.File import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -127,7 +128,15 @@ private[r] object RRDD { sparkConf.setExecutorEnv(name.toString, value.toString) } - val jsc = new JavaSparkContext(sparkConf) + if (sparkEnvirMap.containsKey("spark.r.sql.derby.temp.dir") && + System.getProperty("derby.stream.error.file") == null) { + // This must be set before SparkContext is instantiated. + System.setProperty("derby.stream.error.file", + Seq(sparkEnvirMap.get("spark.r.sql.derby.temp.dir").toString, "derby.log") + .mkString(File.separator)) + } + + val jsc = new JavaSparkContext(SparkContext.getOrCreate(sparkConf)) jars.foreach { jar => jsc.addJar(jar) } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 7ef64723d9593..88118392003e8 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -152,7 +152,7 @@ private[spark] class RRunner[U]( dataOut.writeInt(mode) if (isDataFrame) { - SerDe.writeObject(dataOut, colNames) + SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null) } if (!iter.hasNext) { @@ -347,6 +347,8 @@ private[r] object RRunner { pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) pb.environment().put("SPARKR_WORKER_PORT", port.toString) pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) + pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) + pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala index 77825e75e5136..fdd8cf62f0e5f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala @@ -84,7 +84,6 @@ private[spark] object RUtils { } } else { // Otherwise, assume the package is local - // TODO: support this for Mesos val sparkRPkgPath = localSparkRPackagePath.getOrElse { throw new SparkException("SPARK_HOME not set. 
Can't locate SparkR package.") } diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 550e075a95129..dad928cdcfd0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -28,13 +28,20 @@ import scala.collection.mutable.WrappedArray * Utility functions to serialize, deserialize objects to / from R */ private[spark] object SerDe { - type ReadObject = (DataInputStream, Char) => Object - type WriteObject = (DataOutputStream, Object) => Boolean + type SQLReadObject = (DataInputStream, Char) => Object + type SQLWriteObject = (DataOutputStream, Object) => Boolean - var sqlSerDe: (ReadObject, WriteObject) = _ + private[this] var sqlReadObject: SQLReadObject = _ + private[this] var sqlWriteObject: SQLWriteObject = _ - def registerSqlSerDe(sqlSerDe: (ReadObject, WriteObject)): Unit = { - this.sqlSerDe = sqlSerDe + def setSQLReadObject(value: SQLReadObject): this.type = { + sqlReadObject = value + this + } + + def setSQLWriteObject(value: SQLWriteObject): this.type = { + sqlWriteObject = value + this } // Type mapping from R to Java @@ -56,32 +63,33 @@ private[spark] object SerDe { dis.readByte().toChar } - def readObject(dis: DataInputStream): Object = { + def readObject(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Object = { val dataType = readObjectType(dis) - readTypedObject(dis, dataType) + readTypedObject(dis, dataType, jvmObjectTracker) } def readTypedObject( dis: DataInputStream, - dataType: Char): Object = { + dataType: Char, + jvmObjectTracker: JVMObjectTracker): Object = { dataType match { case 'n' => null case 'i' => new java.lang.Integer(readInt(dis)) case 'd' => new java.lang.Double(readDouble(dis)) case 'b' => new java.lang.Boolean(readBoolean(dis)) case 'c' => readString(dis) - case 'e' => readMap(dis) + case 'e' => readMap(dis, jvmObjectTracker) case 'r' => readBytes(dis) - case 'a' => readArray(dis) - case 'l' => readList(dis) + case 'a' => readArray(dis, jvmObjectTracker) + case 'l' => readList(dis, jvmObjectTracker) case 'D' => readDate(dis) case 't' => readTime(dis) - case 'j' => JVMObjectTracker.getObject(readString(dis)) + case 'j' => jvmObjectTracker(JVMObjectId(readString(dis))) case _ => - if (sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { - val obj = (sqlSerDe._1)(dis, dataType) + val obj = sqlReadObject(dis, dataType) if (obj == null) { throw new IllegalArgumentException (s"Invalid type $dataType") } else { @@ -181,28 +189,28 @@ private[spark] object SerDe { } // All elements of an array must be of the same type - def readArray(dis: DataInputStream): Array[_] = { + def readArray(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[_] = { val arrType = readObjectType(dis) arrType match { case 'i' => readIntArr(dis) case 'c' => readStringArr(dis) case 'd' => readDoubleArr(dis) case 'b' => readBooleanArr(dis) - case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x)) + case 'j' => readStringArr(dis).map(x => jvmObjectTracker(JVMObjectId(x))) case 'r' => readBytesArr(dis) case 'a' => val len = readInt(dis) - (0 until len).map(_ => readArray(dis)).toArray + (0 until len).map(_ => readArray(dis, jvmObjectTracker)).toArray case 'l' => val len = readInt(dis) - (0 until len).map(_ => readList(dis)).toArray + (0 until len).map(_ => readList(dis, jvmObjectTracker)).toArray case _ => - if 
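The SerDe hunks above split `registerSqlSerDe` into two chained setters, `setSQLReadObject` and `setSQLWriteObject`. A sketch of registering hooks with hypothetical handlers; the `'s'` tag and the handler bodies are illustrative assumptions, and SerDe itself is `private[spark]`, so the sketch assumes it is compiled inside Spark:

```scala
package org.apache.spark.api.r

import java.io.{DataInputStream, DataOutputStream}

object SqlSerDeHooksExample {
  private def readSqlObject(dis: DataInputStream, dataType: Char): Object = dataType match {
    case 's' => SerDe.readString(dis)   // pretend 's' tags a SQL-specific string payload
    case _ => null                      // null makes SerDe reject the tag as an invalid type
  }

  private def writeSqlObject(dos: DataOutputStream, obj: Object): Boolean = obj match {
    case s: String =>
      SerDe.writeType(dos, "character") // reuse an existing tag purely for illustration
      SerDe.writeString(dos, s)
      true
    case _ => false                     // false lets SerDe fall back to writing a "jobj"
  }

  def register(): Unit = {
    // The two setters return this.type, so they chain.
    SerDe.setSQLReadObject(readSqlObject).setSQLWriteObject(writeSqlObject)
  }
}
```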
(sqlSerDe == null || sqlSerDe._1 == null) { + if (sqlReadObject == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { val len = readInt(dis) (0 until len).map { _ => - val obj = (sqlSerDe._1)(dis, arrType) + val obj = sqlReadObject(dis, arrType) if (obj == null) { throw new IllegalArgumentException (s"Invalid array type $arrType") } else { @@ -215,17 +223,19 @@ private[spark] object SerDe { // Each element of a list can be of different type. They are all represented // as Object on JVM side - def readList(dis: DataInputStream): Array[Object] = { + def readList(dis: DataInputStream, jvmObjectTracker: JVMObjectTracker): Array[Object] = { val len = readInt(dis) - (0 until len).map(_ => readObject(dis)).toArray + (0 until len).map(_ => readObject(dis, jvmObjectTracker)).toArray } - def readMap(in: DataInputStream): java.util.Map[Object, Object] = { + def readMap( + in: DataInputStream, + jvmObjectTracker: JVMObjectTracker): java.util.Map[Object, Object] = { val len = readInt(in) if (len > 0) { // Keys is an array of String - val keys = readArray(in).asInstanceOf[Array[Object]] - val values = readList(in) + val keys = readArray(in, jvmObjectTracker).asInstanceOf[Array[Object]] + val values = readList(in, jvmObjectTracker) keys.zip(values).toMap.asJava } else { @@ -272,7 +282,11 @@ private[spark] object SerDe { } } - private def writeKeyValue(dos: DataOutputStream, key: Object, value: Object): Unit = { + private def writeKeyValue( + dos: DataOutputStream, + key: Object, + value: Object, + jvmObjectTracker: JVMObjectTracker): Unit = { if (key == null) { throw new IllegalArgumentException("Key in map can't be null.") } else if (!key.isInstanceOf[String]) { @@ -280,10 +294,10 @@ private[spark] object SerDe { } writeString(dos, key.asInstanceOf[String]) - writeObject(dos, value) + writeObject(dos, value, jvmObjectTracker) } - def writeObject(dos: DataOutputStream, obj: Object): Unit = { + def writeObject(dos: DataOutputStream, obj: Object, jvmObjectTracker: JVMObjectTracker): Unit = { if (obj == null) { writeType(dos, "void") } else { @@ -373,14 +387,14 @@ private[spark] object SerDe { case v: Array[Object] => writeType(dos, "list") writeInt(dos, v.length) - v.foreach(elem => writeObject(dos, elem)) + v.foreach(elem => writeObject(dos, elem, jvmObjectTracker)) // Handle Properties // This must be above the case java.util.Map below. 
// (Properties implements Map and will be serialized as map otherwise) case v: java.util.Properties => writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) // Handle map case v: java.util.Map[_, _] => @@ -392,19 +406,21 @@ private[spark] object SerDe { val key = entry.getKey val value = entry.getValue - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + writeKeyValue( + dos, key.asInstanceOf[Object], value.asInstanceOf[Object], jvmObjectTracker) } case v: scala.collection.Map[_, _] => writeType(dos, "map") writeInt(dos, v.size) - v.foreach { case (key, value) => - writeKeyValue(dos, key.asInstanceOf[Object], value.asInstanceOf[Object]) + v.foreach { case (k1, v1) => + writeKeyValue(dos, k1.asInstanceOf[Object], v1.asInstanceOf[Object], jvmObjectTracker) } case _ => - if (sqlSerDe == null || sqlSerDe._2 == null || !(sqlSerDe._2)(dos, value)) { + val sqlWriteSucceeded = sqlWriteObject != null && sqlWriteObject(dos, value) + if (!sqlWriteSucceeded) { writeType(dos, "jobj") - writeJObj(dos, value) + writeJObj(dos, value, jvmObjectTracker) } } } @@ -447,9 +463,9 @@ private[spark] object SerDe { out.write(value) } - def writeJObj(out: DataOutputStream, value: Object): Unit = { - val objId = JVMObjectTracker.put(value) - writeString(out, objId) + def writeJObj(out: DataOutputStream, value: Object, jvmObjectTracker: JVMObjectTracker): Unit = { + val JVMObjectId(id) = jvmObjectTracker.addAndGetId(value) + writeString(out, id) } def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index fd7b4fc88b697..ece4ae6ab0310 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,9 +24,8 @@ import org.apache.spark.SparkConf /** * An interface for all the broadcast implementations in Spark (to allow - * multiple broadcast implementations). SparkContext uses a user-specified - * BroadcastFactory implementation to instantiate a particular broadcast for the - * entire Spark job. + * multiple broadcast implementations). SparkContext uses a BroadcastFactory + * implementation to instantiate a particular broadcast for the entire Spark job. 
*/ private[spark] trait BroadcastFactory { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e8d6d587b4824..039df75ce74fd 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io._ import java.nio.ByteBuffer +import java.util.zip.Adler32 import scala.collection.JavaConverters._ import scala.reflect.ClassTag @@ -28,7 +29,7 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} +import org.apache.spark.storage._ import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} @@ -77,6 +78,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } // Note: use getSizeAsKb (not bytes) to maintain compatibility if no units are provided blockSize = conf.getSizeAsKb("spark.broadcast.blockSize", "4m").toInt * 1024 + checksumEnabled = conf.getBoolean("spark.broadcast.checksum", true) } setConf(SparkEnv.get.conf) @@ -85,10 +87,27 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** Total number of blocks this broadcast variable contains. */ private val numBlocks: Int = writeBlocks(obj) + /** Whether to generate checksum for blocks or not. */ + private var checksumEnabled: Boolean = false + /** The checksum for all the blocks. */ + private var checksums: Array[Int] = _ + override protected def getValue() = { _value } + private def calcChecksum(block: ByteBuffer): Int = { + val adler = new Adler32() + if (block.hasArray) { + adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position) + } else { + val bytes = new Array[Byte](block.remaining()) + block.duplicate.get(bytes) + adler.update(bytes) + } + adler.getValue.toInt + } + /** * Divide the object into multiple blocks and put those blocks in the block manager. * @@ -105,7 +124,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } val blocks = TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) + if (checksumEnabled) { + checksums = new Array[Int](blocks.length) + } blocks.zipWithIndex.foreach { case (block, i) => + if (checksumEnabled) { + checksums(i) = calcChecksum(block) + } val pieceId = BroadcastBlockId(id, "piece" + i) val bytes = new ChunkedByteBuffer(block.duplicate()) if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { @@ -116,10 +141,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[ChunkedByteBuffer] = { + private def readBlocks(): Array[BlockData] = { // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. 
- val blocks = new Array[ChunkedByteBuffer](numBlocks) + val blocks = new Array[BlockData](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -135,13 +160,20 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) case None => bm.getRemoteBytes(pieceId) match { case Some(b) => + if (checksumEnabled) { + val sum = calcChecksum(b.chunks(0)) + if (sum != checksums(pid)) { + throw new SparkException(s"corrupt remote block $pieceId of $broadcastId:" + + s" $sum != ${checksums(pid)}") + } + } // We found the block from remote executors/driver's BlockManager, so put the block // in this executor's BlockManager. if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } - blocks(pid) = b + blocks(pid) = new ByteBufferBlockData(b, true) case None => throw new SparkException(s"Failed to get $pieceId of $broadcastId") } @@ -175,26 +207,34 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) TorrentBroadcast.synchronized { setConf(SparkEnv.get.conf) val blockManager = SparkEnv.get.blockManager - blockManager.getLocalValues(broadcastId).map(_.data.next()) match { - case Some(x) => - releaseLock(broadcastId) - x.asInstanceOf[T] - + blockManager.getLocalValues(broadcastId) match { + case Some(blockResult) => + if (blockResult.data.hasNext) { + val x = blockResult.data.next().asInstanceOf[T] + releaseLock(broadcastId) + x + } else { + throw new SparkException(s"Failed to get locally stored broadcast data: $broadcastId") + } case None => logInfo("Started reading broadcast variable " + id) val startTimeMs = System.currentTimeMillis() - val blocks = readBlocks().flatMap(_.getChunks()) + val blocks = readBlocks() logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - val obj = TorrentBroadcast.unBlockifyObject[T]( - blocks, SparkEnv.get.serializer, compressionCodec) - // Store the merged copy in BlockManager so other tasks on this executor don't - // need to re-fetch it. - val storageLevel = StorageLevel.MEMORY_AND_DISK - if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { - throw new SparkException(s"Failed to store $broadcastId in BlockManager") + try { + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks.map(_.toInputStream()), SparkEnv.get.serializer, compressionCodec) + // Store the merged copy in BlockManager so other tasks on this executor don't + // need to re-fetch it. 
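Broadcast pieces are now checksummed with Adler32 when `spark.broadcast.checksum` is enabled (the default), and a mismatch on a remotely fetched piece raises a `SparkException`. A standalone sketch of the same checksum computation over a `ByteBuffer`, mirroring `calcChecksum` above:

```scala
import java.nio.ByteBuffer
import java.util.zip.Adler32

object ChecksumExample {
  // Checksum only the readable bytes, without disturbing the buffer's position.
  def calcChecksum(block: ByteBuffer): Int = {
    val adler = new Adler32()
    if (block.hasArray) {
      adler.update(block.array, block.arrayOffset + block.position, block.limit - block.position)
    } else {
      val bytes = new Array[Byte](block.remaining())
      block.duplicate.get(bytes)
      adler.update(bytes)
    }
    adler.getValue.toInt
  }

  def main(args: Array[String]): Unit = {
    val data = ByteBuffer.wrap("broadcast piece".getBytes("UTF-8"))
    val sum = calcChecksum(data)
    // A receiver recomputes the sum and compares it against the sender's value.
    assert(sum == calcChecksum(data.duplicate()))
    println(f"adler32 = $sum%08x")
  }
}
```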
+ val storageLevel = StorageLevel.MEMORY_AND_DISK + if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) { + throw new SparkException(s"Failed to store $broadcastId in BlockManager") + } + obj + } finally { + blocks.foreach(_.dispose()) } - obj } } } @@ -241,12 +281,11 @@ private object TorrentBroadcast extends Logging { } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[InputStream], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") - val is = new SequenceInputStream( - blocks.iterator.map(new ByteBufferInputStream(_)).asJavaEnumeration) + val is = new SequenceInputStream(blocks.iterator.asJavaEnumeration) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index ee276e1b71138..bf6093236d92b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -123,7 +123,7 @@ private class ClientEndpoint( Thread.sleep(5000) logInfo("... polling master for driver state") val statusResponse = - activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) + activeMasterEndpoint.askSync[DriverStatusResponse](RequestDriverStatus(driverId)) if (statusResponse.found) { logInfo(s"State of $driverId is ${statusResponse.state.get}") // Worker node, if present @@ -221,7 +221,9 @@ object Client { val conf = new SparkConf() val driverArgs = new ClientArguments(args) - conf.set("spark.rpc.askTimeout", "10") + if (!conf.contains("spark.rpc.askTimeout")) { + conf.set("spark.rpc.askTimeout", "10s") + } Logger.getRootLogger.setLevel(driverArgs.logLevel) val rpcEnv = diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index ac09c6c497f8b..b5cb3f0a0f9dc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -43,7 +43,7 @@ private[deploy] object DeployMessages { memory: Int, workerWebUiUrl: String) extends DeployMessage { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } @@ -131,7 +131,7 @@ private[deploy] object DeployMessages { // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { - Utils.checkHostPort(hostPort, "Required hostport") + Utils.checkHostPort(hostPort) } case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String], @@ -183,7 +183,7 @@ private[deploy] object DeployMessages { completedDrivers: Array[DriverInfo], status: MasterState) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) def uri: String = "spark://" + host + ":" + port @@ -201,7 +201,7 @@ private[deploy] object DeployMessages { drivers: List[DriverRunner], finishedDrivers: List[DriverRunner], masterUrl: String, cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) { - Utils.checkHost(host, "Required hostname") + Utils.checkHost(host) assert (port > 0) } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala 
b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 13eadbe44f612..8d491ddf6e092 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -25,8 +25,8 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.TransportContext +import org.apache.spark.network.crypto.AuthServerBootstrap import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.SaslServerBootstrap import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf @@ -47,7 +47,6 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) - private val useSasl: Boolean = securityManager.isAuthenticationEnabled() private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", numUsableCores = 0) @@ -74,10 +73,11 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana /** Start the external shuffle service */ def start() { require(server == null, "Shuffle server already started") - logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + val authEnabled = securityManager.isAuthenticationEnabled() + logInfo(s"Starting shuffle service on port $port (auth enabled = $authEnabled)") val bootstraps: Seq[TransportServerBootstrap] = - if (useSasl) { - Seq(new SaslServerBootstrap(transportConf, securityManager)) + if (authEnabled) { + Seq(new AuthServerBootstrap(transportConf, securityManager)) } else { Nil } diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala index e917679c83877..357a9769311a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleServiceSource.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy import javax.annotation.concurrent.ThreadSafe -import com.codahale.metrics.{Gauge, MetricRegistry} +import com.codahale.metrics.MetricRegistry import org.apache.spark.metrics.source.Source import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 79f4d06c8460e..c6307da61c7eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -43,8 +43,7 @@ import org.apache.spark.util.{ThreadUtils, Utils} * Execute using * ./bin/spark-class org.apache.spark.deploy.FaultToleranceTest * - * Make sure that that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS - * *and* SPARK_JAVA_OPTS: + * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS: * - spark.deploy.recoveryMode=ZOOKEEPER * - spark.deploy.zookeeper.url=172.17.42.1:2181 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port. 
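The shuffle-service change above swaps the SASL-specific bootstrap for `AuthServerBootstrap`, keyed purely on whether authentication is enabled. A `SparkConf` sketch of the related settings; the secret value is obviously a placeholder:

```scala
import org.apache.spark.SparkConf

object ShuffleServiceAuthConf {
  def main(args: Array[String]): Unit = {
    // Illustrative settings only; spark.shuffle.service.port defaults to 7337.
    val conf = new SparkConf()
      .set("spark.shuffle.service.enabled", "true")
      .set("spark.shuffle.service.port", "7337")
      .set("spark.authenticate", "true")                 // turns on the auth bootstrap
      .set("spark.authenticate.secret", "not-a-real-secret")
    println(conf.getBoolean("spark.shuffle.service.enabled", false))
  }
}
```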
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 0b1cec2df8303..a8f732b11f6cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -85,6 +85,7 @@ object PythonRunner { // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) + sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize try { val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index 3d2cabcdfdd5d..050778a895c0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -176,26 +176,31 @@ private[deploy] object RPackageUtils extends Logging { val file = new File(Utils.resolveURI(jarPath)) if (file.exists()) { val jar = new JarFile(file) - if (checkManifestForR(jar)) { - print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) - val rSource = extractRFolder(jar, printStream, verbose) - if (RUtils.rPackages.isEmpty) { - RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) - } - try { - if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { - print(s"ERROR: Failed to build R package in $file.", printStream) - print(RJarDoc, printStream) + Utils.tryWithSafeFinally { + if (checkManifestForR(jar)) { + print(s"$file contains R source code. Now installing package.", printStream, Level.INFO) + val rSource = extractRFolder(jar, printStream, verbose) + if (RUtils.rPackages.isEmpty) { + RUtils.rPackages = Some(Utils.createTempDir().getAbsolutePath) } - } finally { // clean up - if (!rSource.delete()) { - logWarning(s"Error deleting ${rSource.getPath()}") + try { + if (!rPackageBuilder(rSource, printStream, verbose, RUtils.rPackages.get)) { + print(s"ERROR: Failed to build R package in $file.", printStream) + print(RJarDoc, printStream) + } + } finally { + // clean up + if (!rSource.delete()) { + logWarning(s"Error deleting ${rSource.getPath()}") + } + } + } else { + if (verbose) { + print(s"$file doesn't contain R source code, skipping...", printStream) } } - } else { - if (verbose) { - print(s"$file doesn't contain R source code, skipping...", printStream) - } + } { + jar.close() } } else { print(s"WARN: $file resolved as dependency, but not found.", printStream, Level.WARNING) @@ -231,8 +236,12 @@ private[deploy] object RPackageUtils extends Logging { val zipOutputStream = new ZipOutputStream(new FileOutputStream(zipFile, false)) try { filesToBundle.foreach { file => - // get the relative paths for proper naming in the zip file - val relPath = file.getAbsolutePath.replaceFirst(dir.getAbsolutePath, "") + // Get the relative paths for proper naming in the ZIP file. Note that + // we convert dir to URI to force / and then remove trailing / that show up for + // directories because the separator should always be / for according to ZIP + // specification and therefore `relPath` here should be, for example, + // "/packageTest/def.R" or "/test.R". 
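The PythonRunner hunk above forwards `PYTHONHASHSEED` from the driver environment to the launched Python process. A generic sketch of that propagation pattern with `ProcessBuilder`, assuming a `python` executable on the PATH:

```scala
object EnvPropagationExample {
  def main(args: Array[String]): Unit = {
    // Forward a selected environment variable from this JVM to a child process.
    val builder = new ProcessBuilder("python", "--version")
    sys.env.get("PYTHONHASHSEED").foreach(builder.environment().put("PYTHONHASHSEED", _))
    builder.redirectErrorStream(true)  // merge stderr into stdout, as PythonRunner does
    val exitCode = builder.start().waitFor()
    println(s"child exited with $exitCode")
  }
}
```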
+ val relPath = file.toURI.toString.replaceFirst(dir.toURI.toString.stripSuffix("/"), "") val fis = new FileInputStream(file) val zipEntry = new ZipEntry(relPath) zipOutputStream.putNextEntry(zipEntry) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 3f54ecc17ac33..9cc321af4bde2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -18,10 +18,9 @@ package org.apache.spark.deploy import java.io.IOException -import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import java.text.DateFormat -import java.util.{Arrays, Comparator, Date} +import java.util.{Arrays, Comparator, Date, Locale} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -29,7 +28,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} -import org.apache.hadoop.fs.FileSystem.Statistics +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -84,17 +83,20 @@ class SparkHadoopUtil extends Logging { // the behavior of the old implementation of this code, for backwards compatibility. if (conf != null) { // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && - System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - val keyId = System.getenv("AWS_ACCESS_KEY_ID") - val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") - + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + if (keyId != null && accessKey != null) { hadoopConf.set("fs.s3.awsAccessKeyId", keyId) hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) hadoopConf.set("fs.s3a.access.key", keyId) hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3a.secret.key", accessKey) + + val sessionToken = System.getenv("AWS_SESSION_TOKEN") + if (sessionToken != null) { + hadoopConf.set("fs.s3a.session.token", sessionToken) + } } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => @@ -140,54 +142,29 @@ class SparkHadoopUtil extends Logging { /** * Returns a function that can be called to find Hadoop FileSystem bytes read. If * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will - * return the bytes read on r since t. Reflection is required because thread-level FileSystem - * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). - * Returns None if the required method can't be found. + * return the bytes read on r since t. + * + * @return None if the required method can't be found. 
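The SparkHadoopUtil hunk above now also forwards `AWS_SESSION_TOKEN` to `fs.s3a.session.token` alongside the access key and secret. A sketch of the same mapping applied to a plain Hadoop `Configuration`; the environment variables are whatever the caller has set:

```scala
import org.apache.hadoop.conf.Configuration

object S3CredentialsExample {
  def main(args: Array[String]): Unit = {
    val hadoopConf = new Configuration()

    // Same env-var to property mapping as the hunk above, restricted to s3a.
    for {
      keyId <- sys.env.get("AWS_ACCESS_KEY_ID")
      accessKey <- sys.env.get("AWS_SECRET_ACCESS_KEY")
    } {
      hadoopConf.set("fs.s3a.access.key", keyId)
      hadoopConf.set("fs.s3a.secret.key", accessKey)
      // Temporary credentials additionally carry a session token.
      sys.env.get("AWS_SESSION_TOKEN").foreach(hadoopConf.set("fs.s3a.session.token", _))
    }
    println(Option(hadoopConf.get("fs.s3a.access.key")).isDefined)
  }
}
```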
*/ - private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { - try { - val threadStats = getFileSystemThreadStatistics() - val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") - val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum - val baselineBytesRead = f() - Some(() => f() - baselineBytesRead) - } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => - logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) - None - } + private[spark] def getFSBytesReadOnThreadCallback(): () => Long = { + val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics) + val f = () => threadStats.map(_.getBytesRead).sum + val baselineBytesRead = f() + () => f() - baselineBytesRead } /** * Returns a function that can be called to find Hadoop FileSystem bytes written. If * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will - * return the bytes written on r since t. Reflection is required because thread-level FileSystem - * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). - * Returns None if the required method can't be found. + * return the bytes written on r since t. + * + * @return None if the required method can't be found. */ - private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { - try { - val threadStats = getFileSystemThreadStatistics() - val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") - val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum - val baselineBytesWritten = f() - Some(() => f() - baselineBytesWritten) - } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => - logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) - None - } - } - - private def getFileSystemThreadStatistics(): Seq[AnyRef] = { - FileSystem.getAllStatistics.asScala.map( - Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) - } - - private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { - val statisticsDataClass = - Utils.classForName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") - statisticsDataClass.getDeclaredMethod(methodName) + private[spark] def getFSBytesWrittenOnThreadCallback(): () => Long = { + val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics) + val f = () => threadStats.map(_.getBytesWritten).sum + val baselineBytesWritten = f() + () => f() - baselineBytesWritten } /** @@ -357,7 +334,7 @@ class SparkHadoopUtil extends Logging { * @return a printable string value. 
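With the reflection gone, the bytes-read callback is just a delta over `getThreadStatistics`, which is available in Hadoop 2.5+. A sketch of the same pattern as the rewritten `getFSBytesReadOnThreadCallback`:

```scala
import scala.collection.JavaConverters._

import org.apache.hadoop.fs.FileSystem

object BytesReadCallbackExample {
  // Returns a callback reporting bytes read on the calling thread since creation.
  def bytesReadOnThreadCallback(): () => Long = {
    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
    val f = () => threadStats.map(_.getBytesRead).sum
    val baseline = f()
    () => f() - baseline
  }

  def main(args: Array[String]): Unit = {
    val bytesRead = bytesReadOnThreadCallback()
    // ... perform some Hadoop FileSystem reads on this thread ...
    println(s"bytes read since callback creation: ${bytesRead()}")
  }
}
```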
*/ private[spark] def tokenToString(token: Token[_ <: TokenIdentifier]): String = { - val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT) + val df = DateFormat.getDateTimeInstance(DateFormat.SHORT, DateFormat.SHORT, Locale.US) val buffer = new StringBuilder(128) buffer.append(token.toString) try { @@ -373,10 +350,32 @@ class SparkHadoopUtil extends Logging { } } catch { case e: IOException => - logDebug("Failed to decode $token: $e", e) + logDebug(s"Failed to decode $token: $e", e) } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5c052286099f5..77005aa9040b5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy -import java.io.{File, PrintStream} +import java.io.{File, IOException} import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException} import java.net.URL import java.security.PrivilegedExceptionAction +import java.text.ParseException import scala.annotation.tailrec import scala.collection.mutable.{ArrayBuffer, HashMap, Map} @@ -41,12 +42,11 @@ import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException} -import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL} +import org.apache.spark._ import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.util._ /** * Whether to submit, kill, or request the status of an application. @@ -63,7 +63,7 @@ private[deploy] object SparkSubmitAction extends Enumeration { * This program handles setting up the classpath with relevant Spark dependencies and provides * a layer over the different cluster managers and deploy modes that Spark supports. 
*/ -object SparkSubmit { +object SparkSubmit extends CommandLineUtils { // Cluster managers private val YARN = 1 @@ -87,15 +87,6 @@ object SparkSubmit { private val CLASS_NOT_FOUND_EXIT_STATUS = 101 // scalastyle:off println - // Exposed for testing - private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) - private[spark] var printStream: PrintStream = System.err - private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) - private[spark] def printErrorAndExit(str: String): Unit = { - printStream.println("Error: " + str) - printStream.println("Run with --help for usage help or --verbose for debug output") - exitFn(1) - } private[spark] def printVersionAndExit(): Unit = { printStream.println("""Welcome to ____ __ @@ -115,7 +106,7 @@ object SparkSubmit { } // scalastyle:on println - def main(args: Array[String]): Unit = { + override def main(args: Array[String]): Unit = { val appArgs = new SparkSubmitArguments(args) if (appArgs.verbose) { // scalastyle:off println @@ -293,8 +284,17 @@ object SparkSubmit { } else { Nil } + + // Create the IvySettings, either load from file or build defaults + val ivySettings = args.sparkProperties.get("spark.jars.ivySettings").map { ivySettingsFile => + SparkSubmitUtils.loadIvySettings(ivySettingsFile, Option(args.repositories), + Option(args.ivyRepoPath)) + }.getOrElse { + SparkSubmitUtils.buildIvySettings(Option(args.repositories), Option(args.ivyRepoPath)) + } + val resolvedMavenCoordinates = SparkSubmitUtils.resolveMavenCoordinates(args.packages, - Option(args.repositories), Option(args.ivyRepoPath), exclusions = exclusions) + ivySettings, exclusions = exclusions) if (!StringUtils.isBlank(resolvedMavenCoordinates)) { args.jars = mergeFileLists(args.jars, resolvedMavenCoordinates) if (args.isPython) { @@ -322,7 +322,7 @@ object SparkSubmit { } // Require all R files to be local - if (args.isR && !isYarnCluster) { + if (args.isR && !isYarnCluster && !isMesosCluster) { if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) { printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}") } @@ -330,9 +330,6 @@ object SparkSubmit { // The following modes are not supported or applicable (clusterManager, deployMode) match { - case (MESOS, CLUSTER) if args.isR => - printErrorAndExit("Cluster deploy mode is currently not supported for R " + - "applications on Mesos clusters.") case (STANDALONE, CLUSTER) if args.isPython => printErrorAndExit("Cluster deploy mode is currently not supported for python " + "applications on standalone clusters.") @@ -410,9 +407,9 @@ object SparkSubmit { printErrorAndExit("Distributing R packages with standalone cluster is not supported.") } - // TODO: Support SparkR with mesos cluster - if (args.isR && clusterManager == MESOS) { - printErrorAndExit("SparkR is not supported for Mesos cluster.") + // TODO: Support distributing R packages with mesos cluster + if (args.isR && clusterManager == MESOS && !RUtils.rPackages.isEmpty) { + printErrorAndExit("Distributing R packages with mesos cluster is not supported.") } // If we're running an R app, set the main class to our specific R runner @@ -488,12 +485,17 @@ object SparkSubmit { // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath - if (deployMode == CLIENT) { + // Also add the main application jar and any added jars to classpath in case YARN client + // requires these jars. 
+ if (deployMode == CLIENT || isYarnCluster) { childMainClass = args.mainClass if (isUserJar(args.primaryResource)) { childClasspath += args.primaryResource } if (args.jars != null) { childClasspath ++= args.jars.split(",") } + } + + if (deployMode == CLIENT) { if (args.childArgs != null) { childArgs ++= args.childArgs } } @@ -598,6 +600,9 @@ object SparkSubmit { if (args.pyFiles != null) { sysProps("spark.submit.pyFiles") = args.pyFiles } + } else if (args.isR) { + // Second argument is main class + childArgs += (args.primaryResource, "") } else { childArgs += (args.primaryResource, args.mainClass) } @@ -665,7 +670,8 @@ object SparkSubmit { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") - printStream.println(s"System properties:\n${sysProps.mkString("\n")}") + // sysProps may contain sensitive information, so redact before printing + printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } @@ -870,30 +876,13 @@ private[spark] object SparkSubmitUtils { /** * Extracts maven coordinates from a comma-delimited string - * @param remoteRepos Comma-delimited string of remote repositories - * @param ivySettings The Ivy settings for this session + * @param defaultIvyUserDir The default user path for Ivy * @return A ChainResolver used by Ivy to search for and resolve dependencies. */ - def createRepoResolvers(remoteRepos: Option[String], ivySettings: IvySettings): ChainResolver = { + def createRepoResolvers(defaultIvyUserDir: File): ChainResolver = { // We need a chain resolver if we want to check multiple repositories val cr = new ChainResolver - cr.setName("list") - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - // scalastyle:off println - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - // scalastyle:on println - } - } + cr.setName("spark-list") val localM2 = new IBiblioResolver localM2.setM2compatible(true) @@ -903,7 +892,7 @@ private[spark] object SparkSubmitUtils { cr.add(localM2) val localIvy = new FileSystemResolver - val localIvyRoot = new File(ivySettings.getDefaultIvyUserDir, "local") + val localIvyRoot = new File(defaultIvyUserDir, "local") localIvy.setLocal(true) localIvy.setRepository(new FileRepository(localIvyRoot)) val ivyPattern = Seq(localIvyRoot.getAbsolutePath, "[organisation]", "[module]", "[revision]", @@ -984,6 +973,87 @@ private[spark] object SparkSubmitUtils { } } + /** + * Build Ivy Settings using options with default resolvers + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return An IvySettings object + */ + def buildIvySettings(remoteRepos: Option[String], ivyPath: Option[String]): IvySettings = { + val ivySettings: IvySettings = new IvySettings + processIvyPathArg(ivySettings, ivyPath) + + // create a pattern matcher + ivySettings.addMatcher(new GlobPatternMatcher) + // create the dependency resolvers + val repoResolver = 
createRepoResolvers(ivySettings.getDefaultIvyUserDir) + ivySettings.addResolver(repoResolver) + ivySettings.setDefaultResolver(repoResolver.getName) + processRemoteRepoArg(ivySettings, remoteRepos) + ivySettings + } + + /** + * Load Ivy settings from a given filename, using supplied resolvers + * @param settingsFile Path to Ivy settings file + * @param remoteRepos Comma-delimited string of remote repositories other than maven central + * @param ivyPath The path to the local ivy repository + * @return An IvySettings object + */ + def loadIvySettings( + settingsFile: String, + remoteRepos: Option[String], + ivyPath: Option[String]): IvySettings = { + val file = new File(settingsFile) + require(file.exists(), s"Ivy settings file $file does not exist") + require(file.isFile(), s"Ivy settings file $file is not a normal file") + val ivySettings: IvySettings = new IvySettings + try { + ivySettings.load(file) + } catch { + case e @ (_: IOException | _: ParseException) => + throw new SparkException(s"Failed when loading Ivy settings from $settingsFile", e) + } + processIvyPathArg(ivySettings, ivyPath) + processRemoteRepoArg(ivySettings, remoteRepos) + ivySettings + } + + /* Set ivy settings for location of cache, if option is supplied */ + private def processIvyPathArg(ivySettings: IvySettings, ivyPath: Option[String]): Unit = { + ivyPath.filterNot(_.trim.isEmpty).foreach { alternateIvyDir => + ivySettings.setDefaultIvyUserDir(new File(alternateIvyDir)) + ivySettings.setDefaultCache(new File(alternateIvyDir, "cache")) + } + } + + /* Add any optional additional remote repositories */ + private def processRemoteRepoArg(ivySettings: IvySettings, remoteRepos: Option[String]): Unit = { + remoteRepos.filterNot(_.trim.isEmpty).map(_.split(",")).foreach { repositoryList => + val cr = new ChainResolver + cr.setName("user-list") + + // add current default resolver, if any + Option(ivySettings.getDefaultResolver).foreach(cr.add) + + // add additional repositories, last resolution in chain takes precedence + repositoryList.zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + // scalastyle:off println + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + // scalastyle:on println + } + + ivySettings.addResolver(cr) + ivySettings.setDefaultResolver(cr.getName) + } + } + /** A nice function to use in tests as well. Values are dummy strings. 
*/ def getModuleDescriptor: DefaultModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance( ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0")) @@ -991,16 +1061,14 @@ private[spark] object SparkSubmitUtils { /** * Resolves any dependencies that were supplied through maven coordinates * @param coordinates Comma-delimited string of maven coordinates - * @param remoteRepos Comma-delimited string of remote repositories other than maven central - * @param ivyPath The path to the local ivy repository + * @param ivySettings An IvySettings containing resolvers to use * @param exclusions Exclusions to apply when resolving transitive dependencies * @return The comma-delimited path to the jars of the given maven artifacts including their * transitive dependencies */ def resolveMavenCoordinates( coordinates: String, - remoteRepos: Option[String], - ivyPath: Option[String], + ivySettings: IvySettings, exclusions: Seq[String] = Nil, isTest: Boolean = false): String = { if (coordinates == null || coordinates.trim.isEmpty) { @@ -1011,32 +1079,14 @@ private[spark] object SparkSubmitUtils { // To prevent ivy from logging to system out System.setOut(printStream) val artifacts = extractMavenCoordinates(coordinates) - // Default configuration name for ivy - val ivyConfName = "default" - // set ivy settings for location of cache - val ivySettings: IvySettings = new IvySettings // Directories for caching downloads through ivy and storing the jars when maven coordinates // are supplied to spark-submit - val alternateIvyCache = ivyPath.getOrElse("") - val packagesDirectory: File = - if (alternateIvyCache == null || alternateIvyCache.trim.isEmpty) { - new File(ivySettings.getDefaultIvyUserDir, "jars") - } else { - ivySettings.setDefaultIvyUserDir(new File(alternateIvyCache)) - ivySettings.setDefaultCache(new File(alternateIvyCache, "cache")) - new File(alternateIvyCache, "jars") - } + val packagesDirectory: File = new File(ivySettings.getDefaultIvyUserDir, "jars") // scalastyle:off println printStream.println( s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}") printStream.println(s"The jars for the packages stored in: $packagesDirectory") // scalastyle:on println - // create a pattern matcher - ivySettings.addMatcher(new GlobPatternMatcher) - // create the dependency resolvers - val repoResolver = createRepoResolvers(remoteRepos, ivySettings) - ivySettings.addResolver(repoResolver) - ivySettings.setDefaultResolver(repoResolver.getName) val ivy = Ivy.newInstance(ivySettings) // Set resolve options to download transitive dependencies as well @@ -1052,6 +1102,9 @@ private[spark] object SparkSubmitUtils { resolveOptions.setDownload(true) } + // Default configuration name for ivy + val ivyConfName = "default" + // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor // clear ivy resolution from previous launches. 
The resolution file is usually at diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index f1761e7c1ec92..0144fd1056bac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -84,9 +84,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + val properties = Utils.getPropertiesFromFile(filename) + properties.foreach { case (k, v) => defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } + // Property files may contain sensitive information, so redact before printing + if (verbose) { + Utils.redact(properties).foreach { case (k, v) => + SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } } } // scalastyle:on println @@ -184,6 +190,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S .orNull numExecutors = Option(numExecutors) .getOrElse(sparkProperties.get("spark.executor.instances").orNull) + queue = Option(queue).orElse(sparkProperties.get("spark.yarn.queue")).orNull keytab = Option(keytab).orElse(sparkProperties.get("spark.yarn.keytab")).orNull principal = Option(principal).orElse(sparkProperties.get("spark.yarn.principal")).orNull @@ -318,7 +325,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | |Spark properties used, including those specified through | --conf and those from the properties file $propertiesFile: - |${sparkProperties.mkString(" ", "\n ", "\n")} + |${Utils.redact(sparkProperties).mkString(" ", "\n ", "\n")} """.stripMargin } @@ -412,10 +419,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S repositories = value case CONF => - value.split("=", 2).toSeq match { - case Seq(k, v) => sparkProperties(k) = v - case _ => SparkSubmit.printErrorAndExit(s"Spark config without '=': $value") - } + val (confName, confValue) = SparkSubmit.parseSparkConfProperty(value) + sparkProperties(confName) = confValue case PROXY_USER => proxyUser = value @@ -508,7 +513,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place | on the PYTHONPATH for Python apps. | --files FILES Comma-separated list of files to be placed in the working - | directory of each executor. + | directory of each executor. File paths of these files + | in executors can be accessed via SparkFiles.get(fileName). | | --conf PROP=VALUE Arbitrary Spark configuration property. | --properties-file FILE Path to a file from which to load extra properties. 
If not diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 06530ff836466..5cb48ca3e60b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -30,7 +30,8 @@ private[spark] case class ApplicationAttemptInfo( endTime: Long, lastUpdated: Long, sparkUser: String, - completed: Boolean = false) + completed: Boolean = false, + appSparkVersion: String) private[spark] case class ApplicationHistoryInfo( id: String, @@ -74,6 +75,30 @@ private[history] case class LoadedAppUI( private[history] abstract class ApplicationHistoryProvider { + /** + * Returns the count of application event logs that the provider is currently still processing. + * History Server UI can use this to indicate to a user that the application listing on the UI + * can be expected to list additional known applications once the processing of these + * application event logs completes. + * + * A History Provider that does not have a notion of count of event logs that may be pending + * for processing need not override this method. + * + * @return Count of application event logs that are currently under process + */ + def getEventLogsUnderProcess(): Int = { + 0 + } + + /** + * Returns the time the history provider last updated the application history information + * + * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis + */ + def getLastUpdatedTime(): Long = { + 0 + } + /** * Returns a list of applications available for the history server to show. * diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index dfc1aad64c818..d05ca142b618b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{FileNotFoundException, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -94,11 +95,17 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) private val logDir = conf.getOption("spark.history.fs.logDirectory") - .map { d => Utils.resolveURI(d).toString } .getOrElse(DEFAULT_LOG_DIR) + private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) + private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") + private val HISTORY_UI_ADMIN_ACLS_GROUPS = conf.get("spark.history.ui.admin.acls.groups", "") + 
logInfo(s"History server ui acls " + (if (HISTORY_UI_ACLS_ENABLE) "enabled" else "disabled") + + "; users with admin permissions: " + HISTORY_UI_ADMIN_ACLS.toString + + "; groups with admin permissions" + HISTORY_UI_ADMIN_ACLS_GROUPS.toString) + private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private val fs = Utils.getHadoopFileSystem(logDir, hadoopConf) + private val fs = new Path(logDir).getFileSystem(hadoopConf) // Used by check event thread and clean log thread. // Scheduled thread pool size must be one, otherwise it will have concurrent issues about fs @@ -108,7 +115,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // The modification time of the newest log detected during the last scan. Currently only // used for logging msgs (logs are re-scanned based on file size, rather than modtime) - private var lastScanTime = -1L + private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted // into the map in order, so the LinkedHashMap maintains the correct ordering. @@ -120,6 +127,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -226,6 +235,10 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) applications.get(appId) } + override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() + + override def getLastUpdatedTime(): Long = lastScanTime.get() + override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { try { applications.get(appId).flatMap { appInfo => @@ -235,7 +248,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) + HistoryServer.getAttemptURI(appId, attempt.attemptId), + attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } @@ -244,13 +258,15 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) if (appListener.appId.isDefined) { - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) + ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") + ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui.getSecurityManager.setAdminAclsGroups(appListener.adminAclsGroups.getOrElse("")) + val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") + ui.getSecurityManager.setAdminAcls(adminAcls) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) + val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + 
appListener.adminAclsGroups.getOrElse("") + ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) Some(LoadedAppUI(ui, updateProbe(appId, attemptId, attempt.fileSize))) } else { @@ -305,21 +321,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. - !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -329,26 +338,43 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") } - logInfos.map { file => - replayExecutor.submit(new Runnable { + + var tasks = mutable.ListBuffer[Future[_]]() + + try { + for (file <- logInfos) { + tasks += replayExecutor.submit(new Runnable { override def run(): Unit = mergeApplicationListing(file) }) } - .foreach { task => - try { - // Wait for all tasks to finish. This makes sure that checkForLogs - // is not scheduled again while some tasks are already running in - // the replayExecutor. - task.get() - } catch { - case e: InterruptedException => - throw e - case e: Exception => - logError("Exception while merging application listings", e) - } + } catch { + // let the iteration over logInfos break, since an exception on + // replayExecutor.submit (..) indicates the ExecutorService is unable + // to take any more submissions at this time + + case e: Exception => + logError(s"Exception while submitting event log for replay", e) + } + + pendingReplayTasksCount.addAndGet(tasks.size) + + tasks.foreach { task => + try { + // Wait for all tasks to finish. This makes sure that checkForLogs + // is not scheduled again while some tasks are already running in + // the replayExecutor. 
+ task.get() + } catch { + case e: InterruptedException => + throw e + case e: Exception => + logError("Exception while merging application listings", e) + } finally { + pendingReplayTasksCount.decrementAndGet() } + } - lastScanTime = newLastScanTime + lastScanTime.set(newLastScanTime) } catch { case e: Exception => logError("Exception in checking for event log updates", e) } @@ -365,7 +391,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } catch { case e: Exception => logError("Exception encountered when attempting to update last scan time", e) - lastScanTime + lastScanTime.get() } finally { if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") @@ -415,17 +441,22 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) } val logPath = fileStatus.getPath() - val appCompleted = isApplicationCompleted(fileStatus) + // Use loading time as lastUpdated since some filesystems don't update modifiedTime + // each time file is updated. However use modifiedTime for completed jobs so lastUpdated + // won't change whenever HistoryServer restarts and reloads the file. + val lastUpdated = if (appCompleted) fileStatus.getModificationTime else clock.getTimeMillis() + val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) // Without an app ID, new logs will render incorrectly in the listing page, so do not list or @@ -438,10 +469,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) appListener.appAttemptId, appListener.startTime.getOrElse(-1L), appListener.endTime.getOrElse(-1L), - fileStatus.getModificationTime(), + lastUpdated, appListener.sparkUser.getOrElse(NOT_STARTED), appCompleted, - fileStatus.getLen() + fileStatus.getLen(), + appListener.appSparkVersion.getOrElse("") ) fileToAppInfo(logPath) = attemptInfo logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") @@ -523,7 +555,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() def shouldClean(attempt: FsApplicationAttemptInfo): Boolean = { - now - attempt.lastUpdated > maxAge && attempt.completed + now - attempt.lastUpdated > maxAge } // Scan all logs from the log directory. @@ -640,9 +672,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) false } - // For testing. 
private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) + /* true to check only for Active NNs status */ + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET, true) } /** @@ -707,6 +739,8 @@ private[history] object FsHistoryProvider { private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" + + private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" } /** @@ -734,9 +768,10 @@ private class FsApplicationAttemptInfo( lastUpdated: Long, sparkUser: String, completed: Boolean, - val fileSize: Long) + val fileSize: Long, + appSparkVersion: String) extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed) { + attemptId, startTime, endTime, lastUpdated, sparkUser, completed, appSparkVersion) { /** extend the superclass string value with the extra attributes of this class */ override def toString: String = { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 96b9ecf43b14c..af14717633409 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -26,17 +26,35 @@ import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { def render(request: HttpServletRequest): Seq[Node] = { + // stripXSS is called first to remove suspicious characters used in XSS attacks val requestedIncomplete = - Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean + Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete) + val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess() + val lastUpdatedTime = parent.getLastUpdatedTime() val providerConfig = parent.getProviderConfig() val content = +
              {providerConfig.map { case (k, v) => <li><strong>{k}:</strong> {v}</li> }}
+ { + if (eventLogsUnderProcessCount > 0) { +

There are {eventLogsUnderProcessCount} event log(s) currently being + processed which may result in additional applications getting listed on this page. + Refresh the page to view updates.

+ } + } + + { + if (lastUpdatedTime > 0) { +

Last updated: {lastUpdatedTime}

+ } + } + { if (allAppsSize > 0) { ++ @@ -46,6 +64,8 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") } else if (requestedIncomplete) {

No incomplete applications found!

+ } else if (eventLogsUnderProcessCount > 0) { +

No completed applications found!

} else {

No completed applications found!

++ parent.emptyListingHtml } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 3175b36b3e56f..d9c8fda99ef97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -179,6 +179,14 @@ class HistoryServer( provider.getListing() } + def getEventLogsUnderProcess(): Int = { + provider.getEventLogsUnderProcess() + } + + def getLastUpdatedTime(): Long = { + provider.getLastUpdatedTime() + } + def getApplicationInfoList: Iterator[ApplicationInfo] = { getApplicationList().map(ApplicationsListResource.appHistoryInfoToPublicAppInfo) } @@ -261,7 +269,7 @@ object HistoryServer extends Logging { Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() - val securityManager = new SecurityManager(conf) + val securityManager = createSecurityManager(conf) val providerName = conf.getOption("spark.history.provider") .getOrElse(classOf[FsHistoryProvider].getName()) @@ -281,6 +289,29 @@ object HistoryServer extends Logging { while(true) { Thread.sleep(Int.MaxValue) } } + /** + * Create a security manager. + * This turns off security in the SecurityManager, so that the History Server can start + * in a Spark cluster where security is enabled. + * @param config configuration for the SecurityManager constructor + * @return the security manager for use in constructing the History Server. + */ + private[history] def createSecurityManager(config: SparkConf): SecurityManager = { + if (config.getBoolean(SecurityManager.SPARK_AUTH_CONF, false)) { + logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") + config.set(SecurityManager.SPARK_AUTH_CONF, "false") + } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + + new SecurityManager(config) + } + def initSecurity() { // If we are accessing HDFS and it has security enabled (Kerberos), we have to login // from a keytab file so that we can access HDFS beyond the kerberos ticket expiration. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 2eddb5ff54479..080ba12c2f0d1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -24,7 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Command-line parser for the master. + * Command-line parser for the [[HistoryServer]]. 
*/ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 8c91aa15167c4..e061939623cbb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.master import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -51,7 +51,8 @@ private[deploy] class Master( private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) @@ -79,7 +80,7 @@ private[deploy] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(address.host, "Expected hostname") + Utils.checkHost(address.host) private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, @@ -230,6 +231,29 @@ private[deploy] class Master( logError("Leadership has been revoked -- master shutting down.") System.exit(0) + case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => + logInfo("Registering worker %s:%d with %d cores, %s RAM".format( + workerHost, workerPort, cores, Utils.megabytesToString(memory))) + if (state == RecoveryState.STANDBY) { + workerRef.send(MasterInStandby) + } else if (idToWorker.contains(id)) { + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) + } else { + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + workerRef, workerWebUiUrl) + if (registerWorker(worker)) { + persistenceEngine.addWorker(worker) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) + schedule() + } else { + val workerAddress = worker.endpoint.address + logWarning("Worker registration failed. 
Attempted to re-register worker at same " + + "address: " + workerAddress) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) + } + } + case RegisterApplication(description, driver) => // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { @@ -385,30 +409,6 @@ private[deploy] class Master( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => - logInfo("Registering worker %s:%d with %d cores, %s RAM".format( - workerHost, workerPort, cores, Utils.megabytesToString(memory))) - if (state == RecoveryState.STANDBY) { - context.reply(MasterInStandby) - } else if (idToWorker.contains(id)) { - context.reply(RegisterWorkerFailed("Duplicate worker ID")) - } else { - val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - workerRef, workerWebUiUrl) - if (registerWorker(worker)) { - persistenceEngine.addWorker(worker) - context.reply(RegisteredWorker(self, masterWebUiUrl)) - schedule() - } else { - val workerAddress = worker.endpoint.address - logWarning("Worker registration failed. Attempted to re-register worker at same " + - "address: " + workerAddress) - context.reply(RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress)) - } - } - case RequestSubmitDriver(description) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + @@ -1045,7 +1045,7 @@ private[deploy] object Master extends Logging { val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) - val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + val portsResponse = masterEndpoint.askSync[BoundPortsResponse](BoundPortsRequest) (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index c63793c16dcef..615d2533cf085 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -60,12 +60,12 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 4e20c10fd1427..c87d6e24b78c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -32,7 +32,7 @@ private[spark] class WorkerInfo( val webUiAddress: String) extends Serializable { - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info diff --git 
a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 18cff3125d6b4..f40896457df95 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -33,8 +33,9 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { - val appId = request.getParameter("appId") - val state = master.askWithRetry[MasterStateResponse](RequestMasterState) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val appId = UIUtils.stripXSS(request.getParameter("appId")) + val state = master.askSync[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId) .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { @@ -83,7 +84,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} -
            <li><strong>Submit Date:</strong> {app.submitDate}</li>
+            <li><strong>Submit Date:</strong> {UIUtils.formatDate(app.submitDate)}</li>
             <li><strong>State:</strong> {app.state}</li>
  • { if (!app.isFinished) { @@ -99,11 +100,11 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
-          <h4> Executor Summary </h4>
+          <h4> Executor Summary ({allExecutors.length}) </h4>

          {executorsTable}
          {
            if (removedExecutors.nonEmpty) {
-              <h4> Removed Executors </h4> ++
+              <h4> Removed Executors ({removedExecutors.length}) </h4>

    ++ removedExecutorsTable } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 3fb860582cc17..bc0bf6a1d9700 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -33,7 +33,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - master.askWithRetry[MasterStateResponse](RequestMasterState) + master.askSync[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { if (parent.killEnabled && parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val id = Option(request.getParameter("id")) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val killFlag = + Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean + val id = Option(UIUtils.stripXSS(request.getParameter("id"))) if (id.isDefined && killFlag) { action(id.get) } @@ -76,7 +78,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Submitted Time", "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) @@ -126,14 +128,14 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Workers

    +

    Workers ({workers.length})

    {workerTable}
    -

    Running Applications

    +

    Running Applications ({activeApps.length})

    {activeAppsTable}
    @@ -142,7 +144,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {if (hasDrivers) {
    -

    Running Drivers

    +

    Running Drivers ({activeDrivers.length})

    {activeDriversTable}
    @@ -152,7 +154,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
    -

    Completed Applications

    +

    Completed Applications ({completedApps.length})

    {completedAppsTable}
    @@ -162,7 +164,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { if (hasDrivers) {
    -

    Completed Drivers

    +

    Completed Drivers ({completedDrivers.length})

    {completedDriversTable}
    @@ -176,8 +178,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { private def workerRow(worker: WorkerInfo): Seq[Node] = { - {worker.id} + { + if (worker.isAlive()) { + + {worker.id} + + } else { + worker.id + } + } {worker.host}:{worker.port} {worker.state} @@ -245,12 +254,15 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => - - {w.id.toString} - ).getOrElse("None")} + if (w.isAlive()) { + + {w.id.toString} + + } else { + w.id.toString + }).getOrElse("None")} {driver.state} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index c19296c7b3e00..56620064c57fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -71,7 +71,7 @@ private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + val response = masterEndpoint.askSync[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion @@ -89,7 +89,7 @@ private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRe extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + val response = masterEndpoint.askSync[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse @@ -174,7 +174,7 @@ private[rest] class StandaloneSubmitRequestServlet( requestMessage match { case submitRequest: CreateSubmissionRequest => val driverDescription = buildDriverDescription(submitRequest) - val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + val response = masterEndpoint.askSync[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 0bedd9a20a969..34e3a4c020c80 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,13 +20,13 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{Date, UUID} +import java.util.{Date, Locale, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.mutable.{HashMap, HashSet, LinkedHashMap} import scala.concurrent.ExecutionContext -import scala.util.{Failure, Random, Success} +import scala.util.Random import scala.util.control.NonFatal import org.apache.spark.{SecurityManager, SparkConf} @@ -55,20 +55,20 @@ private[deploy] class Worker( private val 
host = rpcEnv.address.host private val port = rpcEnv.address.port - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(host) assert (port > 0) // A scheduled executor used to send messages at the specified time. private val forwordMessageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") - // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` - // methods. + // A separated thread to clean up the workDir and the directories of finished applications. + // Used to provide the implicit parameter of `Future` methods. private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) // For worker and executor IDs - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -187,8 +187,7 @@ private[deploy] class Worker( webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() - val scheme = if (webUi.sslOptions.enabled) "https" else "http" - workerWebUiUrl = s"$scheme://$publicAddress:${webUi.boundPort}" + workerWebUiUrl = s"http://$publicAddress:${webUi.boundPort}" registerWithMaster() metricsSystem.registerSource(workerSource) @@ -217,7 +216,7 @@ private[deploy] class Worker( try { logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) - registerWithMaster(masterEndpoint) + sendRegisterMessageToMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -273,7 +272,7 @@ private[deploy] class Worker( try { logInfo("Connecting to master " + masterAddress + "...") val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) - registerWithMaster(masterEndpoint) + sendRegisterMessageToMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) @@ -342,19 +341,8 @@ private[deploy] class Worker( } } - private def registerWithMaster(masterEndpoint: RpcEndpointRef): Unit = { - masterEndpoint.ask[RegisterWorkerResponse](RegisterWorker( - workerId, host, port, self, cores, memory, workerWebUiUrl)) - .onComplete { - // This is a very fast action so we can use "ThreadUtils.sameThread" - case Success(msg) => - Utils.tryLogNonFatalError { - handleRegisterResponse(msg) - } - case Failure(e) => - logError(s"Cannot register with master: ${masterEndpoint.address}", e) - System.exit(1) - }(ThreadUtils.sameThread) + private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = { + masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl)) } private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized { @@ -395,6 +383,9 @@ private[deploy] class Worker( } override def receive: PartialFunction[Any, Unit] = synchronized { + case msg: RegisterWorkerResponse => + handleRegisterResponse(msg) + case SendHeartbeat => if (connected) { sendToMaster(Heartbeat(workerId, self)) } @@ -454,12 +445,25 @@ private[deploy] class Worker( // Create local dirs for the executor. 
These are passed to the executor via the // SPARK_EXECUTOR_DIRS environment variable, and deleted by the Worker when the // application finishes. - val appLocalDirs = appDirectories.getOrElse(appId, - Utils.getOrCreateLocalRootDirs(conf).map { dir => - val appDir = Utils.createDirectory(dir, namePrefix = "executor") - Utils.chmod700(appDir) - appDir.getAbsolutePath() - }.toSeq) + val appLocalDirs = appDirectories.getOrElse(appId, { + val localRootDirs = Utils.getOrCreateLocalRootDirs(conf) + val dirs = localRootDirs.flatMap { dir => + try { + val appDir = Utils.createDirectory(dir, namePrefix = "executor") + Utils.chmod700(appDir) + Some(appDir.getAbsolutePath()) + } catch { + case e: IOException => + logWarning(s"${e.getMessage}. Ignoring this directory.") + None + } + }.toSeq + if (dirs.isEmpty) { + throw new IOException("No subfolder can be created in " + + s"${localRootDirs.mkString(",")}.") + } + dirs + }) appDirectories(appId) = appLocalDirs val manager = new ExecutorRunner( appId, @@ -574,10 +578,15 @@ private[deploy] class Worker( if (shouldCleanup) { finishedApps -= id appDirectories.remove(id).foreach { dirList => - logInfo(s"Cleaning up local directories for application $id") - dirList.foreach { dir => - Utils.deleteRecursively(new File(dir)) - } + concurrent.Future { + logInfo(s"Cleaning up local directories for application $id") + dirList.foreach { dir => + Utils.deleteRecursively(new File(dir)) + } + }(cleanupThreadExecutor).onFailure { + case e: Throwable => + logError(s"Clean up app dir $dirList failed: ${e.getMessage}", e) + }(cleanupThreadExecutor) } shuffleService.applicationRemoved(id) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 777020d4d5c84..bd07d342e04ac 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -68,12 +68,12 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { @tailrec private def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => - Utils.checkHost(value, "ip no longer supported, please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) case ("--host" | "-h") :: value :: tail => - Utils.checkHost(value, "Please use hostname " + value) + Utils.checkHost(value) host = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 465c214362b25..2f5a5642d3cab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -22,8 +22,6 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} -import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} - import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils @@ -35,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val supportedLogTypes = Set("stderr", "stdout") private val defaultBytes = 100 * 1024 + // stripXSS is called first to remove suspicious characters used in XSS attacks def renderLog(request: HttpServletRequest): String = { - val appId = Option(request.getParameter("appId")) - val executorId = 
Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => @@ -57,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with pre + logText } + // stripXSS is called first to remove suspicious characters used in XSS attacks def render(request: HttpServletRequest): Seq[Node] = { - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + val appId = Option(UIUtils.stripXSS(request.getParameter("appId"))) + val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId"))) + val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId"))) + val logType = UIUtils.stripXSS(request.getParameter("logType")) + val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong) + val byteLength = + Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt) + .getOrElse(defaultBytes) val (logDir, params, pageName) = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 8ebcbcb6a1738..1ad973122b609 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -34,12 +34,12 @@ private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) + val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) + val workerState = workerEndpoint.askSync[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 7eec4ae64f296..a2f1aa22b0063 100644 --- 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable @@ -72,7 +73,7 @@ private[spark] class CoarseGrainedExecutorBackend( def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) - .map(e => (e._1.substring(prefix.length).toLowerCase, e._2)) + .map(e => (e._1.substring(prefix.length).toLowerCase(Locale.ROOT), e._2)) } override def receive: PartialFunction[Any, Unit] = { @@ -92,17 +93,16 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { - val taskDesc = ser.deserialize[TaskDescription](data.value) + val taskDesc = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber, - taskDesc.name, taskDesc.serializedTask) + executor.launchTask(this, taskDesc) } - case KillTask(taskId, _, interruptThread) => + case KillTask(taskId, _, interruptThread, reason) => if (executor == null) { exitExecutor(1, "Received KillTask command but executor was null") } else { - executor.killTask(taskId, interruptThread) + executor.killTask(taskId, interruptThread, reason) } case StopExecutor => @@ -191,17 +191,16 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf - val port = executorConf.getInt("spark.executor.port", 0) val fetcher = RpcEnv.create( "driverPropsFetcher", hostname, - port, + -1, executorConf, new SecurityManager(executorConf), clientMode = true) val driver = fetcher.setupEndpointRefByURI(driverUrl) - val props = driver.askWithRetry[Seq[(String, String)]](RetrieveSparkProps) ++ - Seq[(String, String)](("spark.app.id", appId)) + val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. 
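The `extractLogUrls` hunk above switches to `toLowerCase(Locale.ROOT)` so the derived map keys do not depend on the JVM's default locale (under a Turkish default locale, for example, plain `toLowerCase` maps `"I"` to a dotless `ı`). A minimal, self-contained sketch of that pattern, using hypothetical names and sample data of my own rather than the executor's actual environment:

```scala
import java.util.Locale

object LogUrlExtractorSketch {
  // Mirror of the pattern above: strip a fixed prefix from environment-variable
  // names and lower-case the remainder with Locale.ROOT so the resulting keys
  // ("stdout", "stderr", ...) are stable across JVM default locales.
  def extractLogUrls(env: Map[String, String]): Map[String, String] = {
    val prefix = "SPARK_LOG_URL_"
    env.filterKeys(_.startsWith(prefix))
      .map { case (k, v) => (k.substring(prefix.length).toLowerCase(Locale.ROOT), v) }
      .toMap
  }

  def main(args: Array[String]): Unit = {
    val env = Map("SPARK_LOG_URL_STDOUT" -> "http://worker:8081/logPage/?logType=stdout")
    println(extractLogUrls(env)) // Map(stdout -> http://worker:8081/logPage/?logType=stdout)
  }
}
```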
@@ -221,7 +220,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, port, cores, isLocal = false) + driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9501dd9cd8e93..3bc47b670305b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -18,23 +18,26 @@ package org.apache.spark.executor import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory -import java.net.URL +import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rpc.RpcTimeout -import org.apache.spark.scheduler.{AccumulableInfo, DirectTaskResult, IndirectTaskResult, Task} +import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskDescription} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} import org.apache.spark.util._ @@ -52,7 +55,8 @@ private[spark] class Executor( executorHostname: String, env: SparkEnv, userClassPath: Seq[URL] = Nil, - isLocal: Boolean = false) + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -67,7 +71,7 @@ private[spark] class Executor( private val conf = env.conf // No ip or host:port - just hostname - Utils.checkHost(executorHostname, "Expected executed slave to be a hostname") + Utils.checkHost(executorHostname) // must not have port specified. assert (0 == Utils.parseHostPort(executorHostname)._2) @@ -78,12 +82,35 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. 
Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) + // Pool used for threads that supervise task killing / cancellation + private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") + // For tasks which are in the process of being killed, this map holds the most recently created + // TaskReaper. All accesses to this map should be synchronized on the map itself (this isn't + // a ConcurrentHashMap because we use the synchronization for purposes other than simply guarding + // the integrity of the map's internal state). The purpose of this map is to prevent the creation + // of a separate TaskReaper for every killTask() of a given task. Instead, this map allows us to + // track whether an existing TaskReaper fulfills the role of a TaskReaper that we would otherwise + // create. The map key is a task id. + private val taskReaperForTask: HashMap[Long, TaskReaper] = HashMap[Long, TaskReaper]() if (!isLocal) { env.metricsSystem.registerSource(executorSource) @@ -93,6 +120,9 @@ private[spark] class Executor( // Whether to load classes in user jars before those in Spark jars private val userClassPathFirst = conf.getBoolean("spark.executor.userClassPathFirst", false) + // Whether to monitor killed / interrupted tasks + private val taskReaperEnabled = conf.getBoolean("spark.task.reaper.enabled", false) + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -135,22 +165,37 @@ private[spark] class Executor( startDriverHeartbeater() - def launchTask( - context: ExecutorBackend, - taskId: Long, - attemptNumber: Int, - taskName: String, - serializedTask: ByteBuffer): Unit = { - val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName, - serializedTask) - runningTasks.put(taskId, tr) + private[executor] def numRunningTasks: Int = runningTasks.size() + + def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { + val tr = new TaskRunner(context, taskDescription) + runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) } - def killTask(taskId: Long, interruptThread: Boolean): Unit = { - val tr = runningTasks.get(taskId) - if (tr != null) { - tr.kill(interruptThread) + def killTask(taskId: Long, interruptThread: Boolean, reason: String): Unit = { + val taskRunner = runningTasks.get(taskId) + if (taskRunner != null) { + if (taskReaperEnabled) { + val maybeNewTaskReaper: Option[TaskReaper] = taskReaperForTask.synchronized { + val shouldCreateReaper = taskReaperForTask.get(taskId) match { + case None => true + case Some(existingReaper) => interruptThread && !existingReaper.interruptThread + } + if (shouldCreateReaper) { + val taskReaper = new TaskReaper( + taskRunner, interruptThread = interruptThread, reason = reason) + taskReaperForTask(taskId) = taskReaper + Some(taskReaper) + } else { + None + } + } + // Execute the TaskReaper from outside of the synchronized block. + maybeNewTaskReaper.foreach(taskReaperPool.execute) + } else { + taskRunner.kill(interruptThread = interruptThread, reason = reason) + } } } @@ -160,13 +205,9 @@ private[spark] class Executor( * tasks instead of taking the JVM down. 
* @param interruptThread whether to interrupt the task thread */ - def killAllTasks(interruptThread: Boolean) : Unit = { - // kill all the running tasks - for (taskRunner <- runningTasks.values().asScala) { - if (taskRunner != null) { - taskRunner.kill(interruptThread) - } - } + def killAllTasks(interruptThread: Boolean, reason: String) : Unit = { + runningTasks.keys().asScala.foreach(t => + killTask(t, interruptThread = interruptThread, reason = reason)) } def stop(): Unit = { @@ -186,19 +227,26 @@ private[spark] class Executor( class TaskRunner( execBackend: ExecutorBackend, - val taskId: Long, - val attemptNumber: Int, - taskName: String, - serializedTask: ByteBuffer) + private val taskDescription: TaskDescription) extends Runnable { - /** Whether this task has been killed. */ - @volatile private var killed = false + val taskId = taskDescription.taskId + val threadName = s"Executor task launch worker for task $taskId" + private val taskName = taskDescription.name + + /** If specified, this task has been killed and this option contains the reason. */ + @volatile private var reasonIfKilled: Option[String] = None + + @volatile private var threadId: Long = -1 + + def getThreadId: Long = threadId /** Whether this task has been finished. */ @GuardedBy("TaskRunner.this") private var finished = false + def isFinished: Boolean = synchronized { finished } + /** How much the JVM process has spent in GC when the task starts to run. */ @volatile var startGCTime: Long = _ @@ -208,13 +256,13 @@ private[spark] class Executor( */ @volatile var task: Task[Any] = _ - def kill(interruptThread: Boolean): Unit = { - logInfo(s"Executor is trying to kill $taskName (TID $taskId)") - killed = true + def kill(interruptThread: Boolean, reason: String): Unit = { + logInfo(s"Executor is trying to kill $taskName (TID $taskId), reason: $reason") + reasonIfKilled = Some(reason) if (task != null) { synchronized { if (!finished) { - task.kill(interruptThread) + task.kill(interruptThread, reason) } } } @@ -229,9 +277,15 @@ private[spark] class Executor( // ClosedByInterruptException during execBackend.statusUpdate which causes // Executor to crash Thread.interrupted() + // Notify any waiting TaskReapers. Generally there will only be one reaper per task but there + // is a rare corner-case where one task can have two reapers in case cancel(interrupt=False) + // is followed by cancel(interrupt=True). Thus we use notifyAll() to avoid a lost wakeup: + notifyAll() } override def run(): Unit = { + threadId = Thread.currentThread.getId + Thread.currentThread.setName(threadName) val threadMXBean = ManagementFactory.getThreadMXBean val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId) val deserializeStartTime = System.currentTimeMillis() @@ -247,26 +301,25 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskProps, taskBytes) = - Task.deserializeWithDependencies(serializedTask) - // Must be set before updateDependencies() is called, in case fetching dependencies // requires access to properties contained within (e.g. for access control). 
- Executor.taskDeserializationProps.set(taskProps) + Executor.taskDeserializationProps.set(taskDescription.properties) - updateDependencies(taskFiles, taskJars) - task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) - task.localProperties = taskProps + updateDependencies(taskDescription.addedFiles, taskDescription.addedJars) + task = ser.deserialize[Task[Any]]( + taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) + task.localProperties = taskDescription.properties task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, // continue executing the task. - if (killed) { + val killReason = reasonIfKilled + if (killReason.isDefined) { // Throw an exception rather than returning, because returning within a try{} block // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl // exception will be caught by the catch block, leading to an incorrect ExceptionFailure // for the task. - throw new TaskKilledException + throw new TaskKilledException(killReason.get) } logDebug("Task " + taskId + "'s epoch is " + task.epoch) @@ -281,7 +334,7 @@ private[spark] class Executor( val value = try { val res = task.run( taskAttemptId = taskId, - attemptNumber = attemptNumber, + attemptNumber = taskDescription.attemptNumber, metricsSystem = env.metricsSystem) threwException = false res @@ -305,19 +358,25 @@ private[spark] class Executor( if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) { throw new SparkException(errMsg) } else { - logWarning(errMsg) + logInfo(errMsg) } } } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"TID ${taskId} completed successfully though internally it encountered " + + s"unrecoverable fetch failures! Most likely this means user code is incorrectly " + + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) + } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime } else 0L // If the task has been killed, let's fail it. - if (task.killed) { - throw new TaskKilledException - } + task.context.killTaskIfInterrupted() val resultSer = env.serializer.newInstance() val beforeSerialization = System.currentTimeMillis() @@ -369,20 +428,32 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => - val reason = ffe.toTaskFailedReason + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => + val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. + val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. 
Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") + } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) - case _: TaskKilledException => - logInfo(s"Executor killed $taskName (TID $taskId)") + case t: TaskKilledException => + logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case _: InterruptedException if task.killed => - logInfo(s"Executor interrupted and killed $taskName (TID $taskId)") + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => + val killReason = task.reasonIfKilled.getOrElse("unknown reason") + logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() - execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) + execBackend.statusUpdate( + taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason))) case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskFailedReason @@ -422,13 +493,129 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { runningTasks.remove(taskId) } } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } + } + + /** + * Supervises the killing / cancellation of a task by sending the interrupted flag, optionally + * sending a Thread.interrupt(), and monitoring the task until it finishes. + * + * Spark's current task cancellation / task killing mechanism is "best effort" because some tasks + * may not be interruptable or may not respond to their "killed" flags being set. If a significant + * fraction of a cluster's task slots are occupied by tasks that have been marked as killed but + * remain running then this can lead to a situation where new jobs and tasks are starved of + * resources that are being used by these zombie tasks. + * + * The TaskReaper was introduced in SPARK-18761 as a mechanism to monitor and clean up zombie + * tasks. For backwards-compatibility / backportability this component is disabled by default + * and must be explicitly enabled by setting `spark.task.reaper.enabled=true`. + * + * A TaskReaper is created for a particular task when that task is killed / cancelled. Typically + * a task will have only one TaskReaper, but it's possible for a task to have up to two reapers + * in case kill is called twice with different values for the `interrupt` parameter. + * + * Once created, a TaskReaper will run until its supervised task has finished running. If the + * TaskReaper has not been configured to kill the JVM after a timeout (i.e. if + * `spark.task.reaper.killTimeout < 0`) then this implies that the TaskReaper may run indefinitely + * if the supervised task never exits. 
+ */ + private class TaskReaper( + taskRunner: TaskRunner, + val interruptThread: Boolean, + val reason: String) + extends Runnable { + + private[this] val taskId: Long = taskRunner.taskId + + private[this] val killPollingIntervalMs: Long = + conf.getTimeAsMs("spark.task.reaper.pollingInterval", "10s") + + private[this] val killTimeoutMs: Long = conf.getTimeAsMs("spark.task.reaper.killTimeout", "-1") + + private[this] val takeThreadDump: Boolean = + conf.getBoolean("spark.task.reaper.threadDump", true) + + override def run(): Unit = { + val startTimeMs = System.currentTimeMillis() + def elapsedTimeMs = System.currentTimeMillis() - startTimeMs + def timeoutExceeded(): Boolean = killTimeoutMs > 0 && elapsedTimeMs > killTimeoutMs + try { + // Only attempt to kill the task once. If interruptThread = false then a second kill + // attempt would be a no-op and if interruptThread = true then it may not be safe or + // effective to interrupt multiple times: + taskRunner.kill(interruptThread = interruptThread, reason = reason) + // Monitor the killed task until it exits. The synchronization logic here is complicated + // because we don't want to synchronize on the taskRunner while possibly taking a thread + // dump, but we also need to be careful to avoid races between checking whether the task + // has finished and wait()ing for it to finish. + var finished: Boolean = false + while (!finished && !timeoutExceeded()) { + taskRunner.synchronized { + // We need to synchronize on the TaskRunner while checking whether the task has + // finished in order to avoid a race where the task is marked as finished right after + // we check and before we call wait(). + if (taskRunner.isFinished) { + finished = true + } else { + taskRunner.wait(killPollingIntervalMs) + } + } + if (taskRunner.isFinished) { + finished = true + } else { + logWarning(s"Killed task $taskId is still running after $elapsedTimeMs ms") + if (takeThreadDump) { + try { + Utils.getThreadDumpForThread(taskRunner.getThreadId).foreach { thread => + if (thread.threadName == taskRunner.threadName) { + logWarning(s"Thread dump from task $taskId:\n${thread.stackTrace}") + } + } + } catch { + case NonFatal(e) => + logWarning("Exception thrown while obtaining thread dump: ", e) + } + } + } + } + + if (!taskRunner.isFinished && timeoutExceeded()) { + if (isLocal) { + logError(s"Killed task $taskId could not be stopped within $killTimeoutMs ms; " + + "not killing JVM because we are running in local mode.") + } else { + // In non-local-mode, the exception thrown here will bubble up to the uncaught exception + // handler and cause the executor JVM to exit. + throw new SparkException( + s"Killing executor JVM because killed task $taskId could not be stopped within " + + s"$killTimeoutMs ms.") + } + } + } finally { + // Clean up entries in the taskReaperForTask map. + taskReaperForTask.synchronized { + taskReaperForTask.get(taskId).foreach { taskReaperInMap => + if (taskReaperInMap eq this) { + taskReaperForTask.remove(taskId) + } else { + // This must have been a TaskReaper where interruptThread == false where a subsequent + // killTask() call for the same task had interruptThread == true and overwrote the + // map entry. + } + } + } + } + } } /** @@ -486,7 +673,7 @@ private[spark] class Executor( * Download any missing dependencies if we receive a new set of files and JARs from the * SparkContext. Also adds any new JARs we fetched to the class loader. 
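To make the TaskReaper knobs concrete, here is an illustrative way to enable it from application code. The keys and defaults below are the ones read by this class (polling every 10s, a `killTimeout` of -1 meaning the JVM is never killed, thread dumps on); the 2-minute timeout is only an example value, not a recommendation from this patch:

```scala
import org.apache.spark.SparkConf

// Illustrative settings only; the reaper stays disabled unless explicitly turned on.
val conf = new SparkConf()
  .set("spark.task.reaper.enabled", "true")
  .set("spark.task.reaper.pollingInterval", "10s") // how often to re-check a killed task
  .set("spark.task.reaper.killTimeout", "2m")      // example: exit the executor JVM after 2 minutes
  .set("spark.task.reaper.threadDump", "true")     // log a thread dump of the still-running task
```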
*/ - private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { + private def updateDependencies(newFiles: Map[String, Long], newJars: Map[String, Long]) { lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies @@ -498,7 +685,7 @@ private[spark] class Executor( currentFiles(name) = timestamp } for ((name, timestamp) <- newJars) { - val localName = name.split("/").last + val localName = new URI(name).getPath.split("/").last val currentTimeStamp = currentJars.get(name) .orElse(currentJars.get(localName)) .getOrElse(-1L) @@ -535,7 +722,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, accumUpdates.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index f7a991770d402..8dd1a1ea059be 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -92,7 +92,7 @@ class ShuffleReadMetrics private[spark] () extends Serializable { private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) /** - * Resets the value of the current metrics (`this`) and and merges all the independent + * Resets the value of the current metrics (`this`) and merges all the independent * [[TempShuffleReadMetrics]] into `this`. */ private[spark] def setMergeValues(metrics: Seq[TempShuffleReadMetrics]): Unit = { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acdac..a3ce3d1ccc5e3 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. 
+ internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f66510b6f977f..9606c4754314f 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -27,6 +27,10 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} +import org.apache.spark.internal.config +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Since + /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -40,9 +44,14 @@ private[spark] abstract class StreamFileInputFormat[T] * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API * which is set through setMaxSplitSize */ - def setMinPartitions(context: JobContext, minPartitions: Int) { - val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum - val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong + def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int) { + val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultParallelism = sc.defaultParallelism + val files = listStatus(context).asScala + val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum + val bytesPerCore = totalBytes / defaultParallelism + val maxSplitSize = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore)) super.setMaxSplitSize(maxSplitSize) } @@ -167,6 +176,7 @@ class PortableDataStream( * Create a new DataInputStream from the split and context. The user of this method is responsible * for closing the stream after usage. 
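A rough worked example of the split-size formula in `setMinPartitions` above, using assumed file sizes and parallelism; the two defaults match `spark.files.maxPartitionBytes` and `spark.files.openCostInBytes` introduced elsewhere in this patch:

```scala
object SplitSizeExample extends App {
  val defaultMaxSplitBytes = 128L * 1024 * 1024   // spark.files.maxPartitionBytes default
  val openCostInBytes      = 4L * 1024 * 1024     // spark.files.openCostInBytes default
  val defaultParallelism   = 8                    // assumed cluster parallelism
  val fileLengths = Seq(200L, 64L, 64L).map(_ * 1024 * 1024) // assumed input file sizes

  val totalBytes   = fileLengths.map(_ + openCostInBytes).sum
  val bytesPerCore = totalBytes / defaultParallelism
  val maxSplitSize = math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))

  // About 42.5 MB per split here: small files get packed together instead of one partition each.
  println(s"maxSplitSize = $maxSplitSize bytes")
}
```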
*/ + @Since("1.2.0") def open(): DataInputStream = { val pathp = split.getPath(index) val fs = pathp.getFileSystem(conf) @@ -176,6 +186,7 @@ class PortableDataStream( /** * Read the file as a byte array */ + @Since("1.2.0") def toArray(): Array[Byte] = { val stream = open() try { @@ -185,6 +196,10 @@ class PortableDataStream( } } + @Since("1.2.0") def getPath(): String = path + + @Since("2.2.0") + def getConfiguration: Configuration = conf } diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala index 013cd1c1bc037..c7f2847731fcb 100644 --- a/core/src/main/scala/org/apache/spark/internal/Logging.scala +++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala @@ -28,7 +28,7 @@ import org.apache.spark.util.Utils * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. */ -private[spark] trait Logging { +trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 0f5c8a9e02ab8..e5d60a7ef0984 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -18,6 +18,9 @@ package org.apache.spark.internal.config import java.util.concurrent.TimeUnit +import java.util.regex.PatternSyntaxException + +import scala.util.matching.Regex import org.apache.spark.network.util.{ByteUnit, JavaUtils} @@ -65,6 +68,13 @@ private object ConfigHelpers { def byteToString(v: Long, unit: ByteUnit): String = unit.convertTo(v, ByteUnit.BYTE) + "b" + def regexFromString(str: String, key: String): Regex = { + try str.r catch { + case e: PatternSyntaxException => + throw new IllegalArgumentException(s"$key should be a regex, but was $str", e) + } + } + } /** @@ -90,6 +100,14 @@ private[spark] class TypedConfigBuilder[T]( new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter) } + /** Checks if the user-provided value for the config matches the validator. */ + def checkValue(validator: T => Boolean, errorMsg: String): TypedConfigBuilder[T] = { + transform { v => + if (!validator(v)) throw new IllegalArgumentException(errorMsg) + v + } + } + /** Check that user-provided values for the config match a pre-defined set. */ def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = { transform { v => @@ -129,6 +147,14 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Creates a [[ConfigEntry]] with a function to determine the default value */ + def createWithDefaultFunction(defaultFunc: () => T): ConfigEntry[T] = { + val entry = new ConfigEntryWithDefaultFunction[T](parent.key, defaultFunc, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_ (entry)) + entry + } + /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. 
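To illustrate the new builder hooks above (`checkValue` and `createWithDefaultFunction`, together with the `regexConf` converter added just below), here are two hypothetical entries written in the same style as the existing ones in `org.apache.spark.internal.config`; the keys and the validator are invented for the example and are not real Spark settings:

```scala
// Hypothetical entries for illustration only; these keys do not exist in Spark.
private[spark] val EXAMPLE_WORKER_THREADS =
  ConfigBuilder("spark.example.workerThreads")
    .intConf
    .checkValue(_ > 0, "spark.example.workerThreads must be positive")
    .createWithDefaultFunction(() => Runtime.getRuntime.availableProcessors())

private[spark] val EXAMPLE_REDACTION_REGEX =
  ConfigBuilder("spark.example.redaction.regex")
    .regexConf
    .createWithDefault("(?i)token|credential".r)
```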
@@ -206,4 +232,7 @@ private[spark] case class ConfigBuilder(key: String) { new FallbackConfigEntry(key, _doc, _public, fallback) } + def regexConf: TypedConfigBuilder[Regex] = { + new TypedConfigBuilder(this, regexFromString(_, this.key), _.toString) + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala index 113037d1ab5be..e86712e84d6ac 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigEntry.scala @@ -17,12 +17,6 @@ package org.apache.spark.internal.config -import java.util.{Map => JMap} - -import scala.util.matching.Regex - -import org.apache.spark.SparkConf - /** * An entry contains all meta information for a configuration. * @@ -34,7 +28,6 @@ import org.apache.spark.SparkConf * value declared as a string. * * @param key the key for the configuration - * @param defaultValue the default value for the configuration * @param valueConverter how to convert a string to the value. It should throw an exception if the * string does not have the required format. * @param stringConverter how to convert a value to a string that the user can use it as a valid @@ -76,7 +69,7 @@ private class ConfigEntryWithDefault[T] ( stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(_defaultValue) @@ -85,7 +78,24 @@ private class ConfigEntryWithDefault[T] ( def readFrom(reader: ConfigReader): T = { reader.get(key).map(valueConverter).getOrElse(_defaultValue) } +} + +private class ConfigEntryWithDefaultFunction[T] ( + key: String, + _defaultFunction: () => T, + valueConverter: String => T, + stringConverter: T => String, + doc: String, + isPublic: Boolean) + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + override def defaultValue: Option[T] = Some(_defaultFunction()) + + override def defaultValueString: String = stringConverter(_defaultFunction()) + + def readFrom(reader: ConfigReader): T = { + reader.get(key).map(valueConverter).getOrElse(_defaultFunction()) + } } private class ConfigEntryWithDefaultString[T] ( @@ -95,7 +105,7 @@ private class ConfigEntryWithDefaultString[T] ( stringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { + extends ConfigEntry(key, valueConverter, stringConverter, doc, isPublic) { override def defaultValue: Option[T] = Some(valueConverter(_defaultValue)) @@ -118,8 +128,8 @@ private[spark] class OptionalConfigEntry[T]( val rawStringConverter: T => String, doc: String, isPublic: Boolean) - extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)), - v => v.map(rawStringConverter).orNull, doc, isPublic) { + extends ConfigEntry[Option[T]](key, s => Some(rawValueConverter(s)), + v => v.map(rawStringConverter).orNull, doc, isPublic) { override def defaultValueString: String = "" @@ -137,7 +147,7 @@ private class FallbackConfigEntry[T] ( doc: String, isPublic: Boolean, private[config] val fallback: ConfigEntry[T]) - extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { + extends ConfigEntry[T](key, fallback.valueConverter, fallback.stringConverter, doc, isPublic) { override def defaultValueString: String = s"" diff --git 
a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala index bb1a3bb5fc56f..c62de9bfd8fc3 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigReader.scala @@ -18,7 +18,6 @@ package org.apache.spark.internal.config import java.util.{Map => JMap} -import java.util.regex.Pattern import scala.collection.mutable.HashMap import scala.util.matching.Regex diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 497ca92c7bc60..7f7921d56f49e 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -114,11 +114,21 @@ package object config { .intConf .createWithDefault(2) + private[spark] val MAX_FAILURES_PER_EXEC = + ConfigBuilder("spark.blacklist.application.maxFailedTasksPerExecutor") + .intConf + .createWithDefault(2) + private[spark] val MAX_FAILURES_PER_EXEC_STAGE = ConfigBuilder("spark.blacklist.stage.maxFailedTasksPerExecutor") .intConf .createWithDefault(2) + private[spark] val MAX_FAILED_EXEC_PER_NODE = + ConfigBuilder("spark.blacklist.application.maxFailedExecutorsPerNode") + .intConf + .createWithDefault(2) + private[spark] val MAX_FAILED_EXEC_PER_NODE_STAGE = ConfigBuilder("spark.blacklist.stage.maxFailedExecutorsPerNode") .intConf @@ -129,6 +139,11 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createOptional + private[spark] val BLACKLIST_KILL_ENABLED = + ConfigBuilder("spark.blacklist.killBlacklistedExecutors") + .booleanConf + .createWithDefault(false) + private[spark] val BLACKLIST_LEGACY_TIMEOUT_CONF = ConfigBuilder("spark.scheduler.executorTaskBlacklistTime") .internal() @@ -198,12 +213,69 @@ package object config { .createWithDefault(0) private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port") - .doc("Port to use for the block managed on the driver.") + .doc("Port to use for the block manager on the driver.") .fallbackConf(BLOCK_MANAGER_PORT) private[spark] val IGNORE_CORRUPT_FILES = ConfigBuilder("spark.files.ignoreCorruptFiles") .doc("Whether to ignore corrupt files. If true, the Spark jobs will continue to run when " + - "encountering corrupt files and contents that have been read will still be returned.") + "encountering corrupted or non-existing files and contents that have been read will still " + + "be returned.") .booleanConf .createWithDefault(false) + + private[spark] val APP_CALLER_CONTEXT = ConfigBuilder("spark.log.callerContext") + .stringConf + .createOptional + + private[spark] val FILES_MAX_PARTITION_BYTES = ConfigBuilder("spark.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) + + private[spark] val FILES_OPEN_COST_IN_BYTES = ConfigBuilder("spark.files.openCostInBytes") + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. 
It's better to" + + " over estimate, then the partitions with small files will be faster than partitions with" + + " bigger files.") + .longConf + .createWithDefault(4 * 1024 * 1024) + + private[spark] val SECRET_REDACTION_PATTERN = + ConfigBuilder("spark.redaction.regex") + .doc("Regex to decide which Spark configuration properties and environment variables in " + + "driver and executor environments contain sensitive information. When this regex matches " + + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") + .regexConf + .createWithDefault("(?i)secret|password".r) + + private[spark] val STRING_REDACTION_PATTERN = + ConfigBuilder("spark.redaction.string.regex") + .doc("Regex to decide which parts of strings produced by Spark contain sensitive " + + "information. When this regex matches a string part, that string part is replaced by a " + + "dummy value. This is currently used to redact the output of SQL explain commands.") + .regexConf + .createOptional + + private[spark] val NETWORK_AUTH_ENABLED = + ConfigBuilder("spark.authenticate") + .booleanConf + .createWithDefault(false) + + private[spark] val SASL_ENCRYPTION_ENABLED = + ConfigBuilder("spark.authenticate.enableSaslEncryption") + .booleanConf + .createWithDefault(false) + + private[spark] val NETWORK_ENCRYPTION_ENABLED = + ConfigBuilder("spark.network.crypto.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala new file mode 100644 index 0000000000000..7efa9416362a0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.fs._ +import org.apache.hadoop.mapreduce._ + +import org.apache.spark.util.Utils + + +/** + * An interface to define how a single Spark job commits its outputs. Two notes: + * + * 1. Implementations must be serializable, as the committer instance instantiated on the driver + * will be used for tasks on executors. + * 2. Implementations should have a constructor with either 2 or 3 arguments: + * (jobId: String, path: String) or (jobId: String, path: String, isAppend: Boolean). + * 3. A committer should not be reused across multiple Spark jobs. + * + * The proper call sequence is: + * + * 1. 
Driver calls setupJob. + * 2. As part of each task's execution, executor calls setupTask and then commitTask + * (or abortTask if task failed). + * 3. When all necessary tasks completed successfully, the driver calls commitJob. If the job + * failed to execute (e.g. too many failed tasks), the job should call abortJob. + */ +abstract class FileCommitProtocol { + import FileCommitProtocol._ + + /** + * Setups up a job. Must be called on the driver before any other methods can be invoked. + */ + def setupJob(jobContext: JobContext): Unit + + /** + * Commits a job after the writes succeed. Must be called on the driver. + */ + def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit + + /** + * Aborts a job after the writes fail. Must be called on the driver. + * + * Calling this function is a best-effort attempt, because it is possible that the driver + * just crashes (or killed) before it can call abort. + */ + def abortJob(jobContext: JobContext): Unit + + /** + * Sets up a task within a job. + * Must be called before any other task related methods can be invoked. + */ + def setupTask(taskContext: TaskAttemptContext): Unit + + /** + * Notifies the commit protocol to add a new file, and gets back the full path that should be + * used. Must be called on the executors when running tasks. + * + * Note that the returned temp file may have an arbitrary path. The commit protocol only + * promises that the file will be at the location specified by the arguments after job commit. + * + * A full file path consists of the following parts: + * 1. the base path + * 2. some sub-directory within the base path, used to specify partitioning + * 3. file prefix, usually some unique job id with the task id + * 4. bucket id + * 5. source specific file extension, e.g. ".snappy.parquet" + * + * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest + * are left to the commit protocol implementation to decide. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + + /** + * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. + * Depending on the implementation, there may be weaker guarantees around adding files this way. + * + * Important: it is the caller's responsibility to add uniquely identifying content to "ext" + * if a task is going to write out multiple files to the same dir. The file commit protocol only + * guarantees that files written by different tasks will not conflict. + */ + def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + + /** + * Commits a task after the writes succeed. Must be called on the executors when running tasks. + */ + def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage + + /** + * Aborts a task after the writes have failed. Must be called on the executors when running tasks. + * + * Calling this function is a best-effort attempt, because it is possible that the executor + * just crashes (or killed) before it can call abort. + */ + def abortTask(taskContext: TaskAttemptContext): Unit + + /** + * Specifies that a file should be deleted with the commit of this job. 
The default + * implementation deletes the file immediately. + */ + def deleteWithJob(fs: FileSystem, path: Path, recursive: Boolean): Boolean = { + fs.delete(path, recursive) + } + + /** + * Called on the driver after a task commits. This can be used to access task commit messages + * before the job has finished. These same task commit messages will be passed to commitJob() + * if the entire job succeeds. + */ + def onTaskCommit(taskCommit: TaskCommitMessage): Unit = {} +} + + +object FileCommitProtocol { + class TaskCommitMessage(val obj: Any) extends Serializable + + object EmptyTaskCommitMessage extends TaskCommitMessage(null) + + /** + * Instantiates a FileCommitProtocol using the given className. + */ + def instantiate(className: String, jobId: String, outputPath: String, isAppend: Boolean) + : FileCommitProtocol = { + val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + + // First try the one with argument (jobId: String, outputPath: String, isAppend: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, isAppend.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala new file mode 100644 index 0000000000000..22e26799138ba --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.util.{Date, UUID} + +import scala.collection.mutable + +import org.apache.hadoop.conf.Configurable +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.internal.Logging +import org.apache.spark.mapred.SparkHadoopMapRedUtil + +/** + * An [[FileCommitProtocol]] implementation backed by an underlying Hadoop OutputCommitter + * (from the newer mapreduce API, not the old mapred API). + * + * Unlike Hadoop's OutputCommitter, this implementation is serializable. 
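A usage note on the reflective `instantiate` helper in the `FileCommitProtocol` companion object above: callers supply the implementation class name, a job id, and the output path, and the helper picks whichever of the 2- or 3-argument constructors the class declares. A hedged sketch with an assumed output path, mirroring how `SparkHadoopMapReduceWriter` later in this patch resolves its committer:

```scala
import org.apache.spark.internal.io.{FileCommitProtocol, HadoopMapReduceCommitProtocol}

// Illustrative only; "/tmp/example-output" is an assumed path, not taken from the source.
val committer: FileCommitProtocol = FileCommitProtocol.instantiate(
  className = classOf[HadoopMapReduceCommitProtocol].getName,
  jobId = "0",
  outputPath = "/tmp/example-output",
  isAppend = false)
```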
+ */ +class HadoopMapReduceCommitProtocol(jobId: String, path: String) + extends FileCommitProtocol with Serializable with Logging { + + import FileCommitProtocol._ + + /** OutputCommitter from Hadoop is not serializable so marking it transient. */ + @transient private var committer: OutputCommitter = _ + + /** + * Tracks files staged by this task for absolute output paths. These outputs are not managed by + * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. + * + * The mapping is from the temp output path to the final desired output path of the file. + */ + @transient private var addedAbsPathFiles: mutable.Map[String, String] = null + + /** + * The staging directory for all files committed with absolute output paths. + */ + private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + val format = context.getOutputFormatClass.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + format match { + case c: Configurable => c.setConf(context.getConfiguration) + case _ => () + } + format.getOutputCommitter(context) + } + + override def newTaskTempFile( + taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { + val filename = getFilename(taskContext, ext) + + val stagingDir: String = committer match { + // For FileOutputCommitter it has its own staging path called "work path". + case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path) + case _ => path + } + + dir.map { d => + new Path(new Path(stagingDir, d), filename).toString + }.getOrElse { + new Path(stagingDir, filename).toString + } + } + + override def newTaskTempFileAbsPath( + taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + val filename = getFilename(taskContext, ext) + val absOutputPath = new Path(absoluteDir, filename).toString + + // Include a UUID here to prevent file collisions for one task writing to different dirs. + // In principle we could include hash(absoluteDir) instead but this is simpler. + val tmpOutputPath = new Path( + absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + + addedAbsPathFiles(tmpOutputPath) = absOutputPath + tmpOutputPath + } + + private def getFilename(taskContext: TaskAttemptContext, ext: String): String = { + // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + f"part-$split%05d-$jobId$ext" + } + + override def setupJob(jobContext: JobContext): Unit = { + // Setup IDs + val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) + val taskId = new TaskID(jobId, TaskType.MAP, 0) + val taskAttemptId = new TaskAttemptID(taskId, 0) + + // Set up the configuration object + jobContext.getConfiguration.set("mapreduce.job.id", jobId.toString) + jobContext.getConfiguration.set("mapreduce.task.id", taskAttemptId.getTaskID.toString) + jobContext.getConfiguration.set("mapreduce.task.attempt.id", taskAttemptId.toString) + jobContext.getConfiguration.setBoolean("mapreduce.task.ismap", true) + jobContext.getConfiguration.setInt("mapreduce.task.partition", 0) + + val taskAttemptContext = new TaskAttemptContextImpl(jobContext.getConfiguration, taskAttemptId) + committer = setupCommitter(taskAttemptContext) + committer.setupJob(jobContext) + } + + override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { + committer.commitJob(jobContext) + val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) + .foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) + } + + override def abortJob(jobContext: JobContext): Unit = { + committer.abortJob(jobContext, JobStatus.State.FAILED) + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } + + override def setupTask(taskContext: TaskAttemptContext): Unit = { + committer = setupCommitter(taskContext) + committer.setupTask(taskContext) + addedAbsPathFiles = mutable.Map[String, String]() + } + + override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { + val attemptId = taskContext.getTaskAttemptID + SparkHadoopMapRedUtil.commitTask( + committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) + new TaskCommitMessage(addedAbsPathFiles.toMap) + } + + override def abortTask(taskContext: TaskAttemptContext): Unit = { + committer.abortTask(taskContext) + // best effort cleanup of other staged files + for ((src, _) <- addedAbsPathFiles) { + val tmp = new Path(src) + tmp.getFileSystem(taskContext.getConfiguration).delete(tmp, false) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala new file mode 100644 index 0000000000000..376ff9bb19f74 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopMapReduceWriter.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.reflect.ClassTag +import scala.util.DynamicVariable + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + +import org.apache.spark.{SparkConf, SparkException, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage +import org.apache.spark.rdd.RDD +import org.apache.spark.util.{SerializableConfiguration, Utils} + +/** + * A helper object that saves an RDD using a Hadoop OutputFormat + * (from the newer mapreduce API, not the old mapred API). + */ +private[spark] +object SparkHadoopMapReduceWriter extends Logging { + + /** + * Basic work flow of this command is: + * 1. Driver side setup, prepare the data source and hadoop configuration for the write job to + * be issued. + * 2. Issues a write job consists of one or more executor side tasks, each of which writes all + * rows within an RDD partition. + * 3. If no exception is thrown in a task, commits that task, otherwise aborts that task; If any + * exception is thrown during task commitment, also aborts that task. + * 4. If all tasks are committed, commit the job, otherwise aborts the job; If any exception is + * thrown during job commitment, also aborts the job. + */ + def write[K, V: ClassTag]( + rdd: RDD[(K, V)], + hadoopConf: Configuration): Unit = { + // Extract context and configuration from RDD. + val sparkContext = rdd.context + val stageId = rdd.id + val sparkConf = rdd.conf + val conf = new SerializableConfiguration(hadoopConf) + + // Set up a job. + val jobTrackerId = SparkHadoopWriterUtils.createJobTrackerID(new Date()) + val jobAttemptId = new TaskAttemptID(jobTrackerId, stageId, TaskType.MAP, 0, 0) + val jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId) + val format = jobContext.getOutputFormatClass + + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(sparkConf)) { + // FileOutputFormat ignores the filesystem parameter + val jobFormat = format.newInstance + jobFormat.checkOutputSpecs(jobContext) + } + + val committer = FileCommitProtocol.instantiate( + className = classOf[HadoopMapReduceCommitProtocol].getName, + jobId = stageId.toString, + outputPath = conf.value.get("mapreduce.output.fileoutputformat.outputdir"), + isAppend = false).asInstanceOf[HadoopMapReduceCommitProtocol] + committer.setupJob(jobContext) + + // Try to write all RDD partitions as a Hadoop OutputFormat. 
+ try { + val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { + executeTask( + context = context, + jobTrackerId = jobTrackerId, + sparkStageId = context.stageId, + sparkPartitionId = context.partitionId, + sparkAttemptNumber = context.attemptNumber, + committer = committer, + hadoopConf = conf.value, + outputFormat = format.asInstanceOf[Class[OutputFormat[K, V]]], + iterator = iter) + }) + + committer.commitJob(jobContext, ret) + logInfo(s"Job ${jobContext.getJobID} committed.") + } catch { + case cause: Throwable => + logError(s"Aborting job ${jobContext.getJobID}.", cause) + committer.abortJob(jobContext) + throw new SparkException("Job aborted.", cause) + } + } + + /** Write an RDD partition out in a single Spark task. */ + private def executeTask[K, V: ClassTag]( + context: TaskContext, + jobTrackerId: String, + sparkStageId: Int, + sparkPartitionId: Int, + sparkAttemptNumber: Int, + committer: FileCommitProtocol, + hadoopConf: Configuration, + outputFormat: Class[_ <: OutputFormat[K, V]], + iterator: Iterator[(K, V)]): TaskCommitMessage = { + // Set up a task. + val attemptId = new TaskAttemptID(jobTrackerId, sparkStageId, TaskType.REDUCE, + sparkPartitionId, sparkAttemptNumber) + val taskContext = new TaskAttemptContextImpl(hadoopConf, attemptId) + committer.setupTask(taskContext) + + val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) + + // Initiate the writer. + val taskFormat = outputFormat.newInstance() + // If OutputFormat is Configurable, we should set conf to it. + taskFormat match { + case c: Configurable => c.setConf(hadoopConf) + case _ => () + } + var writer = taskFormat.getRecordWriter(taskContext) + .asInstanceOf[RecordWriter[K, V]] + require(writer != null, "Unable to obtain RecordWriter") + var recordsWritten = 0L + + // Write all rows in RDD partition. + try { + val ret = Utils.tryWithSafeFinallyAndFailureCallbacks { + // Write rows out, release resource and commit the task. + while (iterator.hasNext) { + val pair = iterator.next() + writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) + recordsWritten += 1 + } + if (writer != null) { + writer.close(taskContext) + writer = null + } + committer.commitTask(taskContext) + }(catchBlock = { + // If there is an error, release resource and then abort the task. + try { + if (writer != null) { + writer.close(taskContext) + writer = null + } + } finally { + committer.abortTask(taskContext) + logError(s"Task ${taskContext.getTaskAttemptID} aborted.") + } + }) + + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + + ret + } catch { + case t: Throwable => + throw new SparkException("Task failed while writing rows", t) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala similarity index 83% rename from core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala rename to core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 6550d703bc860..acc9c38571007 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -15,18 +15,17 @@ * limitations under the License. 
*/ -package org.apache.spark +package org.apache.spark.internal.io import java.io.IOException -import java.text.NumberFormat -import java.text.SimpleDateFormat -import java.util.Date +import java.text.{NumberFormat, SimpleDateFormat} +import java.util.{Date, Locale} import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.TaskType +import org.apache.spark.SerializableWritable import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -67,12 +66,12 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { def setup(jobid: Int, splitid: Int, attemptid: Int) { setIDs(jobid, splitid, attemptid) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(now), + HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(now), jobid, splitID, attemptID, conf.value) } def open() { - val numfmt = NumberFormat.getInstance() + val numfmt = NumberFormat.getInstance(Locale.US) numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) @@ -153,29 +152,8 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { splitID = splitid attemptID = attemptid - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) + jID = new SerializableWritable[JobID](SparkHadoopWriterUtils.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } - -private[spark] -object SparkHadoopWriter { - def createJobID(time: Date, id: Int): JobID = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(time) - new JobID(jobtrackerID, id) - } - - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } -} diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala new file mode 100644 index 0000000000000..de828a6d6156e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriterUtils.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import java.text.SimpleDateFormat +import java.util.{Date, Locale} + +import scala.util.DynamicVariable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred.{JobConf, JobID} + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.OutputMetrics + +/** + * A helper object that provide common utils used during saving an RDD using a Hadoop OutputFormat + * (both from the old mapred API and the new mapreduce API) + */ +private[spark] +object SparkHadoopWriterUtils { + + private val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 + + def createJobID(time: Date, id: Int): JobID = { + val jobtrackerID = createJobTrackerID(time) + new JobID(jobtrackerID, id) + } + + def createJobTrackerID(time: Date): String = { + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time) + } + + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + + // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation + // setting can take effect: + def isOutputSpecValidationEnabled(conf: SparkConf): Boolean = { + val validationDisabled = disableOutputSpecValidation.value + val enabledInConf = conf.getBoolean("spark.hadoop.validateOutputSpecs", true) + enabledInConf && !validationDisabled + } + + // TODO: these don't seem like the right abstractions. + // We should abstract the duplicate code in a less awkward way. + + def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, () => Long) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() + (context.taskMetrics().outputMetrics, bytesWrittenCallback) + } + + def maybeUpdateOutputMetrics( + outputMetrics: OutputMetrics, + callback: () => Long, + recordsWritten: Long): Unit = { + if (recordsWritten % RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) + } + } + + /** + * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case + * basis; see SPARK-4835 for more details. + */ + val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ae014becef755..0cb16f0627b72 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -18,6 +18,7 @@ package org.apache.spark.io import java.io._ +import java.util.Locale import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import net.jpountz.lz4.LZ4BlockOutputStream @@ -32,9 +33,8 @@ import org.apache.spark.util.Utils * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. * - * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. - * This is intended for use as an internal compression utility within a single - * Spark application. 
+ * @note The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single Spark application. */ @DeveloperApi trait CompressionCodec { @@ -67,13 +67,13 @@ private[spark] object CompressionCodec { } def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { - val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName) + val codecClass = + shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) } catch { - case e: ClassNotFoundException => None - case e: IllegalArgumentException => None + case _: ClassNotFoundException | _: IllegalArgumentException => None } codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " + s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC")) @@ -103,9 +103,9 @@ private[spark] object CompressionCodec { * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.lz4.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -123,9 +123,9 @@ class LZ4CompressionCodec(conf: SparkConf) extends CompressionCodec { * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -143,9 +143,9 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by `spark.io.compression.snappy.blockSize`. * - * Note: The wire protocol for this codec is not guaranteed to be compatible across versions - * of Spark. This is intended for use as an internal compression utility within a single Spark - * application. + * @note The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ @DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { @@ -173,7 +173,7 @@ private final object SnappyCompressionCodec { } /** - * Wrapper over [[SnappyOutputStream]] which guards against write-after-close and double-close + * Wrapper over `SnappyOutputStream` which guards against write-after-close and double-close * issues. See SPARK-7660 for more details. 
This wrapping can be removed if we upgrade to a version * of snappy-java that contains the fix for https://github.com/xerial/snappy-java/issues/107. */ diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index 31b9c5edf003f..4216b2627309e 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -39,8 +39,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) - CommandBuilderUtils.addPermGenSizeOpt(cmd) - addOptionString(cmd, getenv("SPARK_JAVA_OPTS")) cmd } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 81b9056b40fb8..fce556fd0382c 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{ConsoleReporter, MetricRegistry} @@ -39,7 +39,7 @@ private[spark] class ConsoleSink(val property: Properties, val registry: MetricR } val pollUnit: TimeUnit = Option(property.getProperty(CONSOLE_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CONSOLE_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 9d5f2ae9328ad..88bba2fdbd1c6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -42,7 +42,7 @@ private[spark] class CsvSink(val property: Properties, val registry: MetricRegis } val pollUnit: TimeUnit = Option(property.getProperty(CSV_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 22454e50b14b4..23e31823f4930 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -18,7 +18,7 @@ package org.apache.spark.metrics.sink import java.net.InetSocketAddress -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry @@ -59,7 +59,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric } val pollUnit: TimeUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } @@ -67,7 +67,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite = 
propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase(Locale.ROOT)) match { case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 773e074336cb0..7fa4ba7622980 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit import com.codahale.metrics.{MetricRegistry, Slf4jReporter} @@ -42,7 +42,7 @@ private[spark] class Slf4jSink( } val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { - case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case Some(s) => TimeUnit.valueOf(s.toUpperCase(Locale.ROOT)) case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) } diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala index 3f7cfd9d2c11f..99ec78633ab75 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala @@ -85,6 +85,17 @@ object HiveCatalogMetrics extends Source { */ val METRIC_FILE_CACHE_HITS = metricRegistry.counter(MetricRegistry.name("fileCacheHits")) + /** + * Tracks the total number of Hive client calls (e.g. to lookup a table). + */ + val METRIC_HIVE_CLIENT_CALLS = metricRegistry.counter(MetricRegistry.name("hiveClientCalls")) + + /** + * Tracks the total number of Spark jobs launched for parallel file listing. + */ + val METRIC_PARALLEL_LISTING_JOB_COUNT = metricRegistry.counter( + MetricRegistry.name("parallelListingJobCount")) + /** * Resets the values of all metrics to zero. This is useful in tests. 
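These sink changes switch `toUpperCase`/`toLowerCase` to `Locale.ROOT` so parsing configuration values does not depend on the JVM's default locale. A standalone illustration (not tied to any sink class) of what goes wrong under, say, a Turkish default locale:

```scala
import java.util.Locale
import java.util.concurrent.TimeUnit

object LocaleSketch {
  def main(args: Array[String]): Unit = {
    val unit = "minutes"
    // Locale-sensitive: under a Turkish locale "i" uppercases to a dotted capital I,
    // so the result is "MİNUTES" and TimeUnit.valueOf would throw.
    val turkish = unit.toUpperCase(new Locale("tr"))
    // Locale-insensitive: Locale.ROOT always yields the ASCII form valueOf expects.
    val root = unit.toUpperCase(Locale.ROOT)
    println(s"tr: $turkish, root: $root")   // tr: MİNUTES, root: MINUTES
    println(TimeUnit.valueOf(root))         // MINUTES
  }
}
```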
 */
@@ -92,10 +103,14 @@ object HiveCatalogMetrics extends Source {
     METRIC_PARTITIONS_FETCHED.dec(METRIC_PARTITIONS_FETCHED.getCount())
     METRIC_FILES_DISCOVERED.dec(METRIC_FILES_DISCOVERED.getCount())
     METRIC_FILE_CACHE_HITS.dec(METRIC_FILE_CACHE_HITS.getCount())
+    METRIC_HIVE_CLIENT_CALLS.dec(METRIC_HIVE_CLIENT_CALLS.getCount())
+    METRIC_PARALLEL_LISTING_JOB_COUNT.dec(METRIC_PARALLEL_LISTING_JOB_COUNT.getCount())
   }
 
   // clients can use these to avoid classloader issues with the codahale classes
   def incrementFetchedPartitions(n: Int): Unit = METRIC_PARTITIONS_FETCHED.inc(n)
   def incrementFilesDiscovered(n: Int): Unit = METRIC_FILES_DISCOVERED.inc(n)
   def incrementFileCacheHits(n: Int): Unit = METRIC_FILE_CACHE_HITS.inc(n)
+  def incrementHiveClientCalls(n: Int): Unit = METRIC_HIVE_CLIENT_CALLS.inc(n)
+  def incrementParallelListingJobCount(n: Int): Unit = METRIC_PARALLEL_LISTING_JOB_COUNT.inc(n)
 }
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 2ed8a00df7023..305fd9a6de10d 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -56,11 +56,12 @@ class NettyBlockRpcServer(
     message match {
       case openBlocks: OpenBlocks =>
-        val blocks: Seq[ManagedBuffer] =
-          openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+        val blocksNum = openBlocks.blockIds.length
+        val blocks = for (i <- (0 until blocksNum).view)
+          yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
         val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
-        logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
+        logTrace(s"Registered streamId $streamId with $blocksNum buffers")
+        responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
 
       case uploadBlock: UploadBlock =>
         // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index dc70eb82d2b54..b75e91b660969 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -27,7 +27,7 @@ import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
-import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
@@ -37,7 +37,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
 
 /**
- * A BlockTransferService that uses Netty to fetch a set of blocks at at time.
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/ private[spark] class NettyBlockTransferService( conf: SparkConf, @@ -63,9 +63,8 @@ private[spark] class NettyBlockTransferService( var serverBootstrap: Option[TransportServerBootstrap] = None var clientBootstrap: Option[TransportClientBootstrap] = None if (authEnabled) { - serverBootstrap = Some(new SaslServerBootstrap(transportConf, securityManager)) - clientBootstrap = Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager, - securityManager.isSaslEncryptionEnabled())) + serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager)) + clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager)) } transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava) diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 86874e2067dd4..25f7bcb9801b9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import scala.collection.JavaConverters._ + import org.apache.spark.SparkConf import org.apache.spark.network.util.{ConfigProvider, TransportConf} @@ -58,6 +60,10 @@ object SparkTransportConf { new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) + override def get(name: String, defaultValue: String): String = conf.get(name, defaultValue) + override def getAll(): java.lang.Iterable[java.util.Map.Entry[String, String]] = { + conf.getAll.toMap.asJava.entrySet() + } }) } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index ab6aba6fc7d6a..8f579c5a3033c 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,7 +28,7 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false + * @note Consistent with Double, any NaN value will make equality false */ override def equals(that: Any): Boolean = that match { diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index 41832e8354741..50d977a92da51 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat private[spark] class BinaryFileRDD[T]( - sc: SparkContext, + @transient private val sc: SparkContext, inputFormatClass: Class[_ <: StreamFileInputFormat[T]], keyClass: Class[String], valueClass: Class[T], @@ -43,7 +43,7 @@ private[spark] class BinaryFileRDD[T]( case _ => } val jobContext = new JobContextImpl(conf, jobId) - inputFormat.setMinPartitions(jobContext, minPartitions) + inputFormat.setMinPartitions(sc, jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git 
a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index d47b75544fdba..4e036c2ed49b5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -47,7 +47,7 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo
     blockManager.get[T](blockId) match {
       case Some(block) => block.data.asInstanceOf[Iterator[T]]
       case None =>
-        throw new Exception("Could not compute split, block " + blockId + " not found")
+        throw new Exception(s"Could not compute split, block $blockId of RDD $id not found")
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 2381f54ee3f06..a091f06b4ed7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -66,14 +66,14 @@ private[spark] class CoGroupPartition(
 
 /**
  * :: DeveloperApi ::
- * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
+ * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
 * tuple with the list of values for that key.
 *
- * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of
- * instantiating this directly.
- *
 * @param rdds parent RDDs.
 * @param part partitioner used to partition the shuffle output
+ *
+ * @note This is an internal API. We recommend users use RDD.cogroup(...) instead of
+ * instantiating this directly.
 */
 @DeveloperApi
 class CoGroupedRDD[K: ClassTag](
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a05a770b40c57..14331dfd0c987 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -152,13 +152,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
 
   /**
    * Compute a histogram using the provided buckets. The buckets are all open
-   * to the right except for the last which is closed
+   * to the right except for the last which is closed.
    * e.g. for the array
    * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
-   * e.g 1<=x<10 , 10<=x<20, 20<=x<=50
+   * e.g. {@code 1<=x<10, 10<=x<20, 20<=x<=50}
   * And on the input of 1 and 50 we would have a histogram of 1, 0, 1
   *
-   * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+   * @note If your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
   * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
   * to true.
   * buckets must be sorted and not contain any duplicates.
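For the bucket semantics documented above, a short worked example (a standalone sketch assuming a local SparkContext) makes the open/closed boundaries concrete:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Buckets [1, 10), [10, 20), [20, 50]: only the last bucket is closed on the right,
// so 50.0 is counted, while 1.0 and 9.0 both land in the first bucket.
object HistogramSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("histogram-sketch"))
    val data = sc.parallelize(Seq(1.0, 9.0, 10.0, 50.0))
    val counts = data.histogram(Array(1.0, 10.0, 20.0, 50.0))
    println(counts.mkString(", "))   // 2, 1, 1
    sc.stop()
  }
}
```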
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index e1cf3938de098..4bf8ecc383542 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -19,21 +19,13 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.collection.immutable.Map import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.mapred.FileSplit -import org.apache.hadoop.mapred.InputFormat -import org.apache.hadoop.mapred.InputSplit -import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapred.JobID -import org.apache.hadoop.mapred.RecordReader -import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.mapred.TaskAttemptID -import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred._ import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils @@ -47,7 +39,7 @@ import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager, Utils} +import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager} /** * A Spark split class that wraps around a Hadoop InputSplit. @@ -84,9 +76,6 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.hadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -97,6 +86,9 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param minPartitions Minimum number of HadoopRDD partitions (Hadoop Splits) to generate. 
+ * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.hadoopRDD()` */ @DeveloperApi class HadoopRDD[K, V]( @@ -131,9 +123,9 @@ class HadoopRDD[K, V]( minPartitions) } - protected val jobConfCacheKey = "rdd_%d_job_conf".format(id) + protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id) - protected val inputFormatCacheKey = "rdd_%d_input_format".format(id) + protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id) // used to build JobTracker ID private val createTime = new Date() @@ -210,53 +202,66 @@ class HadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new NextIterator[(K, V)] { - val split = theSplit.asInstanceOf[HadoopPartition] + private val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - val jobConf = getJobConf() + private val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.inputSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { + private val getBytesReadCallback: Option[() => Long] = split.inputSplit.value match { case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()) case _ => None } - // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // We get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - var reader: RecordReader[K, V] = null - val inputFormat = getInputFormat(jobConf) - HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmmss").format(createTime), + private var reader: RecordReader[K, V] = null + private val inputFormat = getInputFormat(jobConf) + HadoopRDD.addLocalConfiguration( + new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(createTime), context.stageId, theSplit.index, context.attemptNumber, jobConf) - reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + reader = + try { + inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener{ context => closeIfNeeded() } - val key: K = reader.createKey() - val value: V = reader.createValue() + private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey() + private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue() override def getNext(): (K, V) = { try { finished = !reader.next(key, value) } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning(s"Skipped the rest content in the corrupted file: ${split.inputSplit}", e) + finished = true } if (!finished) { inputMetrics.incRecordsRead(1) @@ -267,13 +272,9 @@ class HadoopRDD[K, V]( (key, value) } - override def close() { + override def close(): Unit = { if (reader != null) { - InputFileNameHolder.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. 
+ InputFileBlockHolder.unset() try { reader.close() } catch { @@ -313,18 +314,10 @@ class HadoopRDD[K, V]( override def getPreferredLocations(split: Partition): Seq[String] = { val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value - val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val lsplit = c.inputSplitWithLocationInfo.cast(hsplit) - val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]] - HadoopRDD.convertSplitLocationInfo(infos) - } catch { - case e: Exception => - logDebug("Failed to use InputSplitWithLocations.", e) - None - } - case None => None + val locs = hsplit match { + case lsplit: InputSplitWithLocationInfo => + HadoopRDD.convertSplitLocationInfo(lsplit.getLocationInfo) + case _ => None } locs.getOrElse(hsplit.getLocations.filter(_ != "localhost")) } @@ -372,11 +365,11 @@ private[spark] object HadoopRDD extends Logging { val jobID = new JobID(jobTrackerId, jobId) val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) - conf.set("mapred.tip.id", taId.getTaskID.toString) - conf.set("mapred.task.id", taId.toString) - conf.setBoolean("mapred.task.is.map", true) - conf.setInt("mapred.task.partition", splitId) - conf.set("mapred.job.id", jobID.toString) + conf.set("mapreduce.task.id", taId.getTaskID.toString) + conf.set("mapreduce.task.attempt.id", taId.toString) + conf.setBoolean("mapreduce.task.ismap", true) + conf.setInt("mapreduce.task.partition", splitId) + conf.set("mapreduce.job.id", jobID.toString) } /** @@ -400,32 +393,12 @@ private[spark] object HadoopRDD extends Logging { } } - private[spark] class SplitInfoReflections { - val inputSplitWithLocationInfo = - Utils.classForName("org.apache.hadoop.mapred.InputSplitWithLocationInfo") - val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo") - val newInputSplit = Utils.classForName("org.apache.hadoop.mapreduce.InputSplit") - val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo") - val splitLocationInfo = Utils.classForName("org.apache.hadoop.mapred.SplitLocationInfo") - val isInMemory = splitLocationInfo.getMethod("isInMemory") - val getLocation = splitLocationInfo.getMethod("getLocation") - } - - private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try { - Some(new SplitInfoReflections) - } catch { - case e: Exception => - logDebug("SplitLocationInfo and other new Hadoop classes are " + - "unavailable. 
Using the older Hadoop location info code.", e) - None - } - - private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Option[Seq[String]] = { + private[spark] def convertSplitLocationInfo( + infos: Array[SplitLocationInfo]): Option[Seq[String]] = { Option(infos).map(_.flatMap { loc => - val reflections = HadoopRDD.SPLIT_INFO_REFLECTIONS.get - val locationStr = reflections.getLocation.invoke(loc).asInstanceOf[String] + val locationStr = loc.getLocation if (locationStr != "localhost") { - if (reflections.isInMemory.invoke(loc).asInstanceOf[Boolean]) { + if (loc.isInMemory) { logDebug(s"Partition $locationStr is cached by Hadoop.") Some(HDFSCacheTaskLocation(locationStr).toString) } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala new file mode 100644 index 0000000000000..ff2f58d81142d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileBlockHolder.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD, NewHadoopRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileBlockHolder { + /** + * A wrapper around some input file information. + * + * @param filePath path of the file read, or empty string if not available. + * @param startOffset starting offset, in bytes, or -1 if not available. + * @param length size of the block, in bytes, or -1 if not available. + */ + private class FileBlock(val filePath: UTF8String, val startOffset: Long, val length: Long) { + def this() { + this(UTF8String.fromString(""), -1, -1) + } + } + + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputBlock: InheritableThreadLocal[FileBlock] = + new InheritableThreadLocal[FileBlock] { + override protected def initialValue(): FileBlock = new FileBlock + } + + /** + * Returns the holding file name or empty string if it is unknown. + */ + def getInputFilePath: UTF8String = inputBlock.get().filePath + + /** + * Returns the starting offset of the block currently being read, or -1 if it is unknown. + */ + def getStartOffset: Long = inputBlock.get().startOffset + + /** + * Returns the length of the block being read, or -1 if it is unknown. + */ + def getLength: Long = inputBlock.get().length + + /** + * Sets the thread-local input block. 
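The holder above keeps per-task file block information in an `InheritableThreadLocal` with a non-null default, so readers never see `null` and threads spawned by a task inherit the current block. A simplified standalone sketch of that pattern (the `CurrentFile` name is hypothetical, and a plain `String` stands in for `UTF8String`):

```scala
// Thread-local holder with a safe "unknown" default value.
object CurrentFile {
  private case class FileBlock(path: String, start: Long, length: Long)

  private val block = new InheritableThreadLocal[FileBlock] {
    override protected def initialValue(): FileBlock = FileBlock("", -1L, -1L)
  }

  def set(path: String, start: Long, length: Long): Unit = {
    require(path != null, "path cannot be null")
    block.set(FileBlock(path, start, length))
  }

  def path: String = block.get().path      // "" when no file is being read
  def unset(): Unit = block.remove()       // restore the default value
}
```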
+ */ + def set(filePath: String, startOffset: Long, length: Long): Unit = { + require(filePath != null, "filePath cannot be null") + require(startOffset >= 0, s"startOffset ($startOffset) cannot be negative") + require(length >= 0, s"length ($length) cannot be negative") + inputBlock.set(new FileBlock(UTF8String.fromString(filePath), startOffset, length)) + } + + /** + * Clears the input file block to default value. + */ + def unset(): Unit = inputBlock.remove() +} diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0970b98071675..aab46b8954bf7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -41,7 +41,10 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. @@ -151,7 +154,10 @@ object JdbcRDD { * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. @@ -191,7 +197,10 @@ object JdbcRDD { * The RDD takes care of closing the connection. * @param sql the text of the query. * The query must contain two ? placeholders for parameters used to partition the results. - * E.g. "select title, author from books where ? <= id and id <= ?" + * For example, + * {{{ + * select title, author from books where ? <= id and id <= ? + * }}} * @param lowerBound the minimum value of the first placeholder * @param upperBound the maximum value of the second placeholder * The lower and upper bounds are inclusive. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index baf31fb658870..ce3a9a2a1e2a8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{Date, Locale} import scala.reflect.ClassTag @@ -57,13 +57,13 @@ private[spark] class NewHadoopPartition( * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * - * Note: Instantiating this class directly is not recommended, please use - * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] - * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. 
* @param valueClass Class of the value associated with the inputFormatClass. + * + * @note Instantiating this class directly is not recommended, please use + * `org.apache.spark.SparkContext.newAPIHadoopRDD()` */ @DeveloperApi class NewHadoopRDD[K, V]( @@ -79,7 +79,7 @@ class NewHadoopRDD[K, V]( // private val serializableConf = new SerializableWritable(_conf) private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") + val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) formatter.format(new Date()) } @@ -132,61 +132,79 @@ class NewHadoopRDD[K, V]( override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = { val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[NewHadoopPartition] + private val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf + private val conf = getConf - val inputMetrics = context.taskMetrics().inputMetrics - val existingBytesRead = inputMetrics.bytesRead + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead - // Sets the thread local variable for the file's name + // Sets InputFileBlockHolder for the file block's information split.serializableHadoopSplit.value match { - case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) - case _ => InputFileNameHolder.unsetInputFileName() + case fs: FileSplit => + InputFileBlockHolder.set(fs.getPath.toString, fs.getStart, fs.getLength) + case _ => + InputFileBlockHolder.unset() } // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } + private val getBytesReadCallback: Option[() => Long] = + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + Some(SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()) + case _ => None + } - // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // We get our input bytes from thread-local Hadoop FileSystem statistics. // If we do a coalesce, however, we are likely to compute multiple partitions in the same // task and in the same thread, in which case we need to avoid override values written by // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { + private def updateBytesRead(): Unit = { getBytesReadCallback.foreach { getBytesRead => inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) } } - val format = inputFormatClass.newInstance + private val format = inputFormatClass.newInstance format match { case configurable: Configurable => configurable.setConf(conf) case _ => } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + private val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + private val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) + private var finished = false + private var reader = + try { + val _reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + _reader + } catch { + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true + null + } // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 + private var havePair = false + private var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { try { finished = !reader.nextKeyValue } catch { - case e: IOException if ignoreCorruptFiles => finished = true + case e: IOException if ignoreCorruptFiles => + logWarning( + s"Skipped the rest content in the corrupted file: ${split.serializableHadoopSplit}", + e) + finished = true } if (finished) { // Close and release the reader here; close() will also be called when the task @@ -213,13 +231,9 @@ class NewHadoopRDD[K, V]( (reader.getCurrentKey, reader.getCurrentValue) } - private def close() { + private def close(): Unit = { if (reader != null) { - InputFileNameHolder.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. 
+ InputFileBlockHolder.unset() try { reader.close() } catch { @@ -259,18 +273,7 @@ class NewHadoopRDD[K, V]( override def getPreferredLocations(hsplit: Partition): Seq[String] = { val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - HadoopRDD.convertSplitLocationInfo(infos) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } + val locs = HadoopRDD.convertSplitLocationInfo(split.getLocationInfo) locs.getOrElse(split.getLocations.filter(_ != "localhost")) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 068f4ed8ad745..58762cc0838cd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -18,33 +18,31 @@ package org.apache.spark.rdd import java.nio.ByteBuffer -import java.text.SimpleDateFormat -import java.util.{Date, HashMap => JHashMap} +import java.util.{HashMap => JHashMap} import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import scala.util.DynamicVariable import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus -import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} -import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.OutputMetrics +import org.apache.spark.internal.io.{SparkHadoopMapReduceWriter, SparkHadoopWriter, + SparkHadoopWriterUtils} import org.apache.spark.internal.Logging import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.random.StratifiedSamplingUtils @@ -59,8 +57,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * :: Experimental :: * Generic function to combine the elements for each key using a custom set of aggregation * functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C - * Note that V and C can be different -- for example, one might group an RDD of type - * (Int, Int) into an RDD of type (Int, Seq[Int]). 
Users provide three functions: + * + * Users provide three functions: * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) @@ -68,6 +66,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * * In addition, users can control the partitioning of the output RDD, and whether to perform * map-side aggregation (if a mapper can produce multiple items with the same key). + * + * @note V and C can be different -- for example, one might group an RDD of type + * (Int, Int) into an RDD of type (Int, Seq[Int]). */ @Experimental def combineByKeyWithClassTag[C]( @@ -108,7 +109,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * functions. This method is here for backward compatibility. It does not provide combiner * classtag information to the shuffle. * - * @see [[combineByKeyWithClassTag]] + * @see `combineByKeyWithClassTag` */ def combineByKey[C]( createCombiner: V => C, @@ -126,7 +127,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * This method is here for backward compatibility. It does not provide combiner * classtag information to the shuffle. * - * @see [[combineByKeyWithClassTag]] + * @see `combineByKeyWithClassTag` */ def combineByKey[C]( createCombiner: V => C, @@ -363,7 +364,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) /** * Count the number of elements for each key, collecting the results to a local Map. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. @@ -398,9 +399,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is + * greater than `p`) would trigger sparse representation of registers, which may reduce the + * memory consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -490,12 +491,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * The ordering of elements within each group is not guaranteed, and may even differ * each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any - * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. 
+ * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * key in memory. If a key has too many values, it can result in an `OutOfMemoryError`. */ def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = self.withScope { // groupByKey shouldn't use map side combine because map side combine does not @@ -514,12 +515,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * resulting RDD with into `numPartitions` partitions. The ordering of elements within * each group is not guaranteed, and may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. * - * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any - * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]]. + * @note As currently implemented, groupByKey must be able to hold all the key-value pairs for any + * key in memory. If a key has too many values, it can result in an `OutOfMemoryError`. */ def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = self.withScope { groupByKey(new HashPartitioner(numPartitions)) @@ -607,7 +608,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * existing partitioner/parallelism level. This method is here for backward compatibility. It * does not provide combiner classtag information to the shuffle. * - * @see [[combineByKeyWithClassTag]] + * @see `combineByKeyWithClassTag` */ def combineByKey[C]( createCombiner: V => C, @@ -635,9 +636,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * within each group is not guaranteed, and may even differ each time the resulting RDD is * evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupByKey(): RDD[(K, Iterable[V])] = self.withScope { groupByKey(defaultPartitioner(self)) @@ -907,20 +908,24 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Return an RDD with the pairs from `this` whose keys are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be less than or equal to us. */ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = self.withScope { subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.length))) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. 
+ */ def subtractByKey[W: ClassTag]( other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = self.withScope { subtractByKey(other, new HashPartitioner(numPartitions)) } - /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ + /** + * Return an RDD with the pairs from `this` whose keys are not in `other`. + */ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = self.withScope { new SubtractedRDD[K, V, W](self, other, p) } @@ -994,7 +999,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) val jobConfiguration = job.getConfiguration - jobConfiguration.set("mapred.output.dir", path) + jobConfiguration.set("mapreduce.output.fileoutputformat.outputdir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1016,7 +1021,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. @@ -1035,10 +1040,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf.setOutputFormat(outputFormatClass) for (c <- codec) { hadoopConf.setCompressMapOutput(true) - hadoopConf.set("mapred.output.compress", "true") + hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") hadoopConf.setMapOutputCompressorClass(c) - hadoopConf.set("mapred.output.compression.codec", c.getCanonicalName) - hadoopConf.set("mapred.output.compression.type", CompressionType.BLOCK.toString) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.codec", c.getCanonicalName) + hadoopConf.set("mapreduce.output.fileoutputformat.compress.type", + CompressionType.BLOCK.toString) } // Use configured output committer if already set @@ -1060,7 +1066,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } FileOutputFormat.setOutputPath(hadoopConf, - SparkHadoopWriter.createPathFromString(path, hadoopConf)) + SparkHadoopWriterUtils.createPathFromString(path, hadoopConf)) saveAsHadoopDataset(hadoopConf) } @@ -1070,86 +1076,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) * output paths required (e.g. a table name to write to) in the same way as it would be * configured for a Hadoop MapReduce job. * - * Note that, we should make sure our tasks are idempotent when speculation is enabled, i.e. do + * @note We should make sure our tasks are idempotent when speculation is enabled, i.e. do * not use output committer that writes data directly. * There is an example in https://issues.apache.org/jira/browse/SPARK-10063 to show the bad * result of using direct output committer with speculation enabled. */ def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { - // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). 
- val hadoopConf = conf - val job = NewAPIHadoopJob.getInstance(hadoopConf) - val formatter = new SimpleDateFormat("yyyyMMddHHmmss") - val jobtrackerID = formatter.format(new Date()) - val stageId = self.id - val jobConfiguration = job.getConfiguration - val wrappedConf = new SerializableConfiguration(jobConfiguration) - val outfmt = job.getOutputFormatClass - val jobFormat = outfmt.newInstance - - if (isOutputSpecValidationEnabled) { - // FileOutputFormat ignores the filesystem parameter - jobFormat.checkOutputSpecs(job) - } - - val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { - val config = wrappedConf.value - /* "reduce task" */ - val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, - context.attemptNumber) - val hadoopContext = new TaskAttemptContextImpl(config, attemptId) - val format = outfmt.newInstance - format match { - case c: Configurable => c.setConf(config) - case _ => () - } - val committer = format.getOutputCommitter(hadoopContext) - committer.setupTask(hadoopContext) - - val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) - - val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]] - require(writer != null, "Unable to obtain RecordWriter") - var recordsWritten = 0L - Utils.tryWithSafeFinallyAndFailureCallbacks { - while (iter.hasNext) { - val pair = iter.next() - writer.write(pair._1, pair._2) - - // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) - recordsWritten += 1 - } - }(finallyBlock = writer.close(hadoopContext)) - committer.commitTask(hadoopContext) - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - 1 - } : Int - - val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) - val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) - val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) - - // When speculation is on and output committer class name contains "Direct", we should warn - // users that they may loss data if they are using a direct output committer. - val speculationEnabled = self.conf.getBoolean("spark.speculation", false) - val outputCommitterClass = jobCommitter.getClass.getSimpleName - if (speculationEnabled && outputCommitterClass.contains("Direct")) { - val warningMessage = - s"$outputCommitterClass may be an output committer that writes data directly to " + - "the final location. Because speculation is enabled, this output committer may " + - "cause data loss (see the case in SPARK-10063). If possible, please use an output " + - "committer that does not have this behavior (e.g. FileOutputCommitter)." 
- logWarning(warningMessage) - } - - jobCommitter.setupJob(jobTaskContext) - self.context.runJob(self, writeShard) - jobCommitter.commitJob(jobTaskContext) + SparkHadoopMapReduceWriter.write( + rdd = self, + hadoopConf = conf) } /** @@ -1178,7 +1113,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " + valueClass.getSimpleName + ")") - if (isOutputSpecValidationEnabled) { + if (SparkHadoopWriterUtils.isOutputSpecValidationEnabled(self.conf)) { // FileOutputFormat ignores the filesystem parameter val ignoredFs = FileSystem.get(hadoopConf) hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf) @@ -1192,8 +1127,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] = - initHadoopOutputMetrics(context) + val (outputMetrics, callback) = SparkHadoopWriterUtils.initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1205,44 +1139,19 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) // Update bytes written metric every few records - maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten) + SparkHadoopWriterUtils.maybeUpdateOutputMetrics(outputMetrics, callback, recordsWritten) recordsWritten += 1 } }(finallyBlock = writer.close()) writer.commit() - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } + outputMetrics.setBytesWritten(callback()) + outputMetrics.setRecordsWritten(recordsWritten) } self.context.runJob(self, writeToFile) writer.commitJob() } - // TODO: these don't seem like the right abstractions. - // We should abstract the duplicate code in a less awkward way. - - // return type: (output metrics, bytes written callback), defined only if the latter is defined - private def initHadoopOutputMetrics( - context: TaskContext): Option[(OutputMetrics, () => Long)] = { - val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() - bytesWrittenCallback.map { b => - (context.taskMetrics().outputMetrics, b) - } - } - - private def maybeUpdateOutputMetrics( - outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)], - recordsWritten: Long): Unit = { - if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) { - outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) => - om.setBytesWritten(callback()) - om.setRecordsWritten(recordsWritten) - } - } - } - /** * Return an RDD with the keys of each tuple. 
*/ @@ -1258,22 +1167,4 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def valueClass: Class[_] = vt.runtimeClass private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) - - // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation - // setting can take effect: - private def isOutputSpecValidationEnabled: Boolean = { - val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value - val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true) - enabledInConf && !validationDisabled - } -} - -private[spark] object PairRDDFunctions { - val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 - - /** - * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case - * basis; see SPARK-4835 for more details. - */ - val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index e9092739b298a..9f8019b80a4dd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -116,7 +116,7 @@ private object ParallelCollectionRDD { */ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { - throw new IllegalArgumentException("Positive number of slices required") + throw new IllegalArgumentException("Positive number of partitions required") } // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 0c6ddda52cee9..ce75a16031a3f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -48,7 +48,7 @@ private[spark] class PruneDependency[T](rdd: RDD[T], partitionFilterFunc: Int => /** * :: DeveloperApi :: - * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on + * An RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index 3b1acacf409b9..6a89ea8786464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -32,7 +32,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) } /** - * A RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, + * An RDD sampled from its parent RDD partition-wise. For each partition of the parent RDD, * a user-specified [[org.apache.spark.util.random.RandomSampler]] instance is used to obtain * a random sample of the records in the partition. The random seeds assigned to the samplers * are guaranteed to have different values. 
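The `@note` rewrites in `PairRDDFunctions.scala` above (and repeated for `groupBy` in `RDD.scala` below) all carry the same performance guidance: when the end goal is a per-key aggregate, `reduceByKey` or `aggregateByKey` should be preferred over `groupByKey`, because they combine values map-side before the shuffle. The following is a minimal sketch of that guidance, not part of this patch; the app name and `local[2]` master are illustrative only.

```scala
import org.apache.spark.{SparkConf, SparkContext}

object GroupVsReduceExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("group-vs-reduce").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3), ("b", 4)))

    // Expensive path: every value for a key is shuffled and buffered
    // before the sum is computed.
    val viaGroup = pairs.groupByKey().mapValues(_.sum)

    // Preferred path: partial sums are computed on the map side, so far less
    // data crosses the shuffle and no per-key buffering is needed.
    val viaReduce = pairs.reduceByKey(_ + _)

    assert(viaGroup.collect().toMap == viaReduce.collect().toMap)
    sc.stop()
  }
}
```

Both RDDs produce the same per-key sums; the difference is only in how much data moves through the shuffle and how much must be held in memory per key, which is exactly what the reworded notes warn about.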
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index db535de9e9bb3..63a87e7f09d85 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -70,8 +70,8 @@ import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, Poi * All of the scheduling and execution in Spark is done based on these methods, allowing each RDD * to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for * reading data from a new storage system) by overriding these functions. Please refer to the - * [[http://people.csail.mit.edu/matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details - * on RDD internals. + * Spark paper + * for more details on RDD internals. */ abstract class RDD[T: ClassTag]( @transient private var _sc: SparkContext, @@ -195,10 +195,14 @@ abstract class RDD[T: ClassTag]( } } - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def persist(): this.type = persist(StorageLevel.MEMORY_ONLY) - /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ + /** + * Persist this RDD with the default storage level (`MEMORY_ONLY`). + */ def cache(): this.type = persist() /** @@ -419,7 +423,8 @@ abstract class RDD[T: ClassTag]( * * This results in a narrow dependency, e.g. if you go from 1000 partitions * to 100 partitions, there will not be a shuffle, instead each of the 100 - * new partitions will claim 10 of the current partitions. + * new partitions will claim 10 of the current partitions. If a larger number + * of partitions is requested, it will stay at the current number of partitions. * * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, * this may result in your computation taking place on fewer nodes than @@ -428,7 +433,7 @@ abstract class RDD[T: ClassTag]( * current upstream partitions will be executed in parallel (per whatever * the current partitioning is). * - * Note: With shuffle = true, you can actually coalesce to a larger number + * @note With shuffle = true, you can actually coalesce to a larger number * of partitions. This is useful if you have a small number of partitions, * say 100, potentially with a few partitions being abnormally large. 
Calling * coalesce(1000, shuffle = true) will result in 1000 partitions with the @@ -469,8 +474,12 @@ abstract class RDD[T: ClassTag]( * @param withReplacement can elements be sampled multiple times (replaced when sampled out) * @param fraction expected size of the sample as a fraction of this RDD's size * without replacement: probability that each element is chosen; fraction must be [0, 1] - * with replacement: expected number of times each element is chosen; fraction must be >= 0 + * with replacement: expected number of times each element is chosen; fraction must be greater + * than or equal to 0 * @param seed seed for the random number generator + * + * @note This is NOT guaranteed to provide exactly the fraction of the count + * of the given [[RDD]]. */ def sample( withReplacement: Boolean, @@ -534,13 +543,13 @@ abstract class RDD[T: ClassTag]( /** * Return a fixed-size sampled subset of this RDD in an array * - * @note this method should only be used if the resulting array is expected to be small, as - * all the data is loaded into the driver's memory. - * * @param withReplacement whether sampling is done with replacement * @param num size of the returned sample * @param seed seed for the random number generator * @return sample of specified size in an array + * + * @note this method should only be used if the resulting array is expected to be small, as + * all the data is loaded into the driver's memory. */ def takeSample( withReplacement: Boolean, @@ -615,7 +624,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. */ def intersection(other: RDD[T]): RDD[T] = withScope { this.map(v => (v, null)).cogroup(other.map(v => (v, null))) @@ -627,7 +636,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param partitioner Partitioner to use for the resulting RDD */ @@ -643,7 +652,7 @@ abstract class RDD[T: ClassTag]( * Return the intersection of this RDD and another one. The output will not contain any duplicate * elements, even if the input RDDs did. Performs a hash partition across the cluster * - * Note that this method performs a shuffle internally. + * @note This method performs a shuffle internally. * * @param numPartitions How many partitions to use in the resulting RDD */ @@ -671,9 +680,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. 
*/ def groupBy[K](f: T => K)(implicit kt: ClassTag[K]): RDD[(K, Iterable[T])] = withScope { groupBy[K](f, defaultPartitioner(this)) @@ -684,9 +693,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K]( f: T => K, @@ -699,9 +708,9 @@ abstract class RDD[T: ClassTag]( * mapping to that key. The ordering of elements within each group is not guaranteed, and * may even differ each time the resulting RDD is evaluated. * - * Note: This operation may be very expensive. If you are grouping in order to perform an - * aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]] - * or [[PairRDDFunctions.reduceByKey]] will provide much better performance. + * @note This operation may be very expensive. If you are grouping in order to perform an + * aggregation (such as a sum or average) over each key, using `PairRDDFunctions.aggregateByKey` + * or `PairRDDFunctions.reduceByKey` will provide much better performance. */ def groupBy[K](f: T => K, p: Partitioner)(implicit kt: ClassTag[K], ord: Ordering[K] = null) : RDD[(K, Iterable[T])] = withScope { @@ -747,8 +756,10 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2) {f(e)} + * {{{ + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2) {f(e)} + * }}} * @param separateWorkingDir Use separate working directories for each task. * @param bufferSize Buffer size for the stdin writer for the piped process. * @param encoding Char encoding used for interacting (via stdin, stdout and stderr) with @@ -788,14 +799,26 @@ abstract class RDD[T: ClassTag]( } /** - * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a - * performance API to be used carefully only if we are sure that the RDD elements are + * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning. + * It is a performance API to be used carefully only if we are sure that the RDD elements are * serializable and don't require closure cleaning. * * @param preservesPartitioning indicates whether the input function preserves the partitioner, * which should be `false` unless this is a pair RDD and the input function doesn't modify * the keys. 
*/ + private[spark] def mapPartitionsWithIndexInternal[U: ClassTag]( + f: (Int, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = withScope { + new MapPartitionsRDD( + this, + (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter), + preservesPartitioning) + } + + /** + * [performance] Spark's internal mapPartitions method that skips closure cleaning. + */ private[spark] def mapPartitionsInternal[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = withScope { @@ -906,7 +929,7 @@ abstract class RDD[T: ClassTag]( /** * Return an array that contains all of the elements in this RDD. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. */ def collect(): Array[T] = withScope { @@ -919,7 +942,7 @@ abstract class RDD[T: ClassTag]( * * The iterator will consume as much memory as the largest partition in this RDD. * - * Note: this results in multiple Spark jobs, and if the input RDD is the result + * @note This results in multiple Spark jobs, and if the input RDD is the result * of a wide transformation (e.g. join with different partitioners), to avoid * recomputing the input RDD should be cached first. */ @@ -1167,10 +1190,15 @@ abstract class RDD[T: ClassTag]( /** * Return the count of each unique value in this RDD as a local map of (value, count) pairs. * - * Note that this method should only be used if the resulting map is expected to be small, as + * @note This method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which - * returns an RDD[T, Long] instead of a map. + * To handle very large results, consider using + * + * {{{ + * rdd.map(x => (x, 1L)).reduceByKey(_ + _) + * }}} + * + * , which returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = withScope { map(value => (value, null)).countByKey() @@ -1208,9 +1236,9 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` - * would trigger sparse representation of registers, which may reduce the memory consumption - * and increase accuracy when the cardinality is small. + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero (`sp` is greater + * than `p`) would trigger sparse representation of registers, which may reduce the memory + * consumption and increase accuracy when the cardinality is small. * * @param p The precision value for the normal set. * `p` must be a value between 4 and `sp` if `sp` is not zero (32 max). @@ -1257,7 +1285,7 @@ abstract class RDD[T: ClassTag]( * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. 
The index assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1271,7 +1299,7 @@ abstract class RDD[T: ClassTag]( * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. * - * Note that some RDDs, such as those returned by groupBy(), do not guarantee order of + * @note Some RDDs, such as those returned by groupBy(), do not guarantee order of * elements in a partition. The unique ID assigned to each element is therefore not guaranteed, * and may even change if the RDD is reevaluated. If a fixed ordering is required to guarantee * the same index assignments, you should sort the RDD with sortByKey() or save it to a file. @@ -1290,10 +1318,10 @@ abstract class RDD[T: ClassTag]( * results from that partition to estimate the number of additional partitions needed to satisfy * the limit. * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * - * @note due to complications in the internal implementation, this method will raise + * @note Due to complications in the internal implementation, this method will raise * an exception if called on an RDD of `Nothing` or `Null`. */ def take(num: Int): Array[T] = withScope { @@ -1355,7 +1383,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of top elements to return @@ -1378,7 +1406,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @note this method should only be used if the resulting array is expected to be small, as + * @note This method should only be used if the resulting array is expected to be small, as * all the data is loaded into the driver's memory. * * @param num k, the number of elements to return @@ -1392,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { @@ -1423,7 +1451,7 @@ abstract class RDD[T: ClassTag]( } /** - * @note due to complications in the internal implementation, this method will raise an + * @note Due to complications in the internal implementation, this method will raise an * exception if called on an RDD of `Nothing` or `Null`. This may be come up in practice * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`. * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.) @@ -1583,14 +1611,15 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD is checkpointed and materialized, either reliably or locally. 
*/ - def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) + def isCheckpointed: Boolean = isCheckpointedAndMaterialized /** * Return whether this RDD is checkpointed and materialized, either reliably or locally. * This is introduced as an alias for `isCheckpointed` to clarify the semantics of the * return value. Exposed for testing. */ - private[spark] def isCheckpointedAndMaterialized: Boolean = isCheckpointed + private[spark] def isCheckpointedAndMaterialized: Boolean = + checkpointData.exists(_.isCheckpointed) /** * Return whether this RDD is marked for local checkpointing. @@ -1719,7 +1748,7 @@ abstract class RDD[T: ClassTag]( /** * Clears the dependencies of this RDD. This method must ensure that all references - * to the original parent RDDs is removed to enable the parent RDDs to be garbage + * to the original parent RDDs are removed to enable the parent RDDs to be garbage * collected. Subclasses of RDD may override this method for implementing their own cleaning * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example. */ @@ -1814,7 +1843,7 @@ abstract class RDD[T: ClassTag]( * Defines implicit functions that provide extra functionalities on RDDs of specific types. * * For example, [[RDD.rddToPairRDDFunctions]] converts an RDD into a [[PairRDDFunctions]] for - * key-value-pair RDDs, and enabling extra functionalities such as [[PairRDDFunctions.reduceByKey]]. + * key-value-pair RDDs, and enabling extra functionalities such as `PairRDDFunctions.reduceByKey`. */ object RDD { diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 429514b4f6bee..6c552d4d12515 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,7 +23,8 @@ import org.apache.spark.Partition /** * Enumeration to manage state transitions of an RDD through checkpointing - * [ Initialized --> checkpointing in progress --> checkpointed ]. + * + * [ Initialized --{@literal >} checkpointing in progress --{@literal >} checkpointed ] */ private[spark] object CheckpointState extends Enumeration { type CheckpointState = Value @@ -32,7 +33,7 @@ private[spark] object CheckpointState extends Enumeration { /** * This class contains all the information related to RDD checkpointing. Each instance of this - * class is associated with a RDD. It manages process of checkpointing of the associated RDD, + * class is associated with an RDD. It manages process of checkpointing of the associated RDD, * as well as, manages the post-checkpoint state by providing the updated partitions, * iterator and preferred locations of the checkpointed RDD. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index eac901d10067c..37c67cee55f90 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -151,7 +159,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { } /** - * Write a RDD partition's data to a checkpoint file. + * Write an RDD partition's data to a checkpoint file. 
*/ def writePartitionToCheckpointFile[T: ClassTag]( path: String, @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -239,12 +252,17 @@ private[spark] object ReliableCheckpointRDD extends Logging { val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration) val fileInputStream = fs.open(partitionerFilePath, bufferSize) val serializer = SparkEnv.get.serializer.newInstance() - val deserializeStream = serializer.deserializeStream(fileInputStream) - val partitioner = Utils.tryWithSafeFinally[Partitioner] { - deserializeStream.readObject[Partitioner] + val partitioner = Utils.tryWithSafeFinally { + val deserializeStream = serializer.deserializeStream(fileInputStream) + Utils.tryWithSafeFinally { + deserializeStream.readObject[Partitioner] + } { + deserializeStream.close() + } } { - deserializeStream.close() + fileInputStream.close() } + logDebug(s"Read partitioner from $partitionerFilePath") Some(partitioner) } catch { @@ -268,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } else { + fileStream + } + } val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 1311b481c7c71..86a332790fb00 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -27,9 +27,10 @@ import org.apache.spark.internal.Logging /** * Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile, - * through an implicit conversion. Note that this can't be part of PairRDDFunctions because - * we need more implicit parameters to convert our keys and values to Writable. + * through an implicit conversion. * + * @note This can't be part of PairRDDFunctions because we need more implicit parameters to + * convert our keys and values to Writable. 
*/ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)], diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 29d5d74650cdb..26eaa9aa3d03f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -25,10 +25,6 @@ import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index: Int = idx - - override def hashCode(): Int = index - - override def equals(other: Any): Boolean = super.equals(other) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index ad1fddbde7b00..60e383afadf1c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.collection.parallel.{ForkJoinTaskSupport, ThreadPoolTaskSupport} +import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool import scala.reflect.ClassTag diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index b0e5ba0865c63..8425b211d6ecf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -29,7 +29,7 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) } /** - * Represents a RDD zipped with its element indices. The ordering is first based on the partition + * Represents an RDD zipped with its element indices. The ordering is first based on the partition * index and then the ordering of items within each partition. So the first item in the first * partition gets index 0, and the last item in the last partition receives the largest index. * diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala index d8a80aa5aeb15..e00bc22aba44d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala +++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala @@ -35,14 +35,14 @@ trait PartitionCoalescer { * @param maxPartitions the maximum number of partitions to have after coalescing * @param parent the parent RDD whose partitions to coalesce * @return an array of [[PartitionGroup]]s, where each element is itself an array of - * [[Partition]]s and represents a partition after coalescing is performed. + * `Partition`s and represents a partition after coalescing is performed. 
*/ def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] } /** * ::DeveloperApi:: - * A group of [[Partition]]s + * A group of `Partition`s * @param prefLoc preferred location for the partition group */ @DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428e..ab72addb2466b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index f527ec86ab7b2..117f51c5b8f2a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc /** - * A callback that [[RpcEndpoint]] can use it to send back a message or failure. It's thread-safe + * A callback that [[RpcEndpoint]] can use to send back a message or failure. It's thread-safe * and can be called in any thread. */ private[spark] trait RpcCallContext { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 0ba95169529e6..97eed540b8f59 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -35,7 +35,7 @@ private[spark] trait RpcEnvFactory { * * The life-cycle of an endpoint is: * - * constructor -> onStart -> receive* -> onStop + * {@code constructor -> onStart -> receive* -> onStop} * * Note: `receive` can be called concurrently. If you want `receive` to be thread-safe, please use * [[ThreadSafeRpcEndpoint]] @@ -63,16 +63,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a - * unmatched message, [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.send` or `RpcCallContext.reply`. If receiving a + * unmatched message, `SparkException` will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.ask]]. If receiving a unmatched message, - * [[SparkException]] will be thrown and sent to `onError`. + * Process messages from `RpcEndpointRef.ask`. If receiving a unmatched message, + * `SparkException` will be thrown and sent to `onError`. 
*/ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala index b9db60a7797d8..fdbccc9e74c37 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala @@ -25,10 +25,11 @@ import org.apache.spark.SparkException * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only * connection and can only be reached via the client that sent the endpoint reference. * - * @param rpcAddress The socket address of the endpoint. + * @param rpcAddress The socket address of the endpoint. It's `null` when this address pointing to + * an endpoint in a client `NettyRpcEnv`. * @param name Name of the endpoint. */ -private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { +private[spark] case class RpcEndpointAddress(rpcAddress: RpcAddress, name: String) { require(name != null, "RpcEndpoint name must be provided.") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 994e18676ec49..4d39f144dd198 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -63,25 +63,21 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultAskTimeout) /** - * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default - * timeout, or throw a SparkException if this fails even after the default number of retries. - * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this - * method retries, the message handling in the receiver side should be idempotent. + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a + * default timeout, throw an exception if this fails. * * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. - * + * @param message the message to send * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any): T = askWithRetry(message, defaultAskTimeout) + def askSync[T: ClassTag](message: Any): T = askSync(message, defaultAskTimeout) /** - * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a - * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method - * retries, the message handling in the receiver side should be idempotent. + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] and get its result within a + * specified timeout, throw an exception if this fails. * * Note: this is a blocking action which may cost a lot of time, so don't call it in a message * loop of [[RpcEndpoint]]. 
@@ -91,33 +87,9 @@ private[spark] abstract class RpcEndpointRef(conf: SparkConf) * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithRetry[T: ClassTag](message: Any, timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = ask[T](message, timeout) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("RpcEndpoint returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - - if (attempts < maxRetries) { - Thread.sleep(retryWaitMs) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) + def askSync[T: ClassTag](message: Any, timeout: RpcTimeout): T = { + val future = ask[T](message, timeout) + timeout.awaitResult(future) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 579122868afc8..530743c03640b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -146,7 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * @param uri URI with location of the file. */ def openChannel(uri: String): ReadableByteChannel - } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2761d39e37029..0557b7a3cc0b7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,15 +19,14 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ -import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.util.Utils +import org.apache.spark.SparkConf +import org.apache.spark.util.{ThreadUtils, Utils} /** - * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. + * An exception thrown if RpcTimeout modifies a `TimeoutException`. 
*/ private[rpc] class RpcTimeoutException(message: String, cause: TimeoutException) extends TimeoutException(message) { initCause(cause) } @@ -72,15 +71,9 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S * is still not ready */ def awaitResult[T](future: Future[T]): T = { - val wrapAndRethrow: PartialFunction[Throwable, T] = { - case NonFatal(t) => - throw new SparkException("Exception thrown in awaitResult", t) - } try { - // scalastyle:off awaitresult - Await.result(future, duration) - // scalastyle:on awaitresult - } catch addMessageIfTimeout.orElse(wrapAndRethrow) + ThreadUtils.awaitResult(future, duration) + } catch addMessageIfTimeout } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e51649a1ecce9..b316e5443f639 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -33,12 +33,12 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ +import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.rpc._ -import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance, SerializationStream} +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, ThreadUtils, Utils} private[netty] class NettyRpcEnv( val conf: SparkConf, @@ -60,8 +60,8 @@ private[netty] class NettyRpcEnv( private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) + java.util.Arrays.asList(new AuthClientBootstrap(transportConf, + securityManager.getSaslUser(), securityManager)) } else { java.util.Collections.emptyList[TransportClientBootstrap] } @@ -111,7 +111,7 @@ private[netty] class NettyRpcEnv( def startServer(bindAddress: String, port: Int): Unit = { val bootstraps: java.util.List[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + java.util.Arrays.asList(new AuthServerBootstrap(transportConf, securityManager)) } else { java.util.Collections.emptyList() } @@ -189,7 +189,7 @@ private[netty] class NettyRpcEnv( } } else { // Message to a remote RPC endpoint. 
- postToOutbox(message.receiver, OneWayOutboxMessage(serialize(message))) + postToOutbox(message.receiver, OneWayOutboxMessage(message.serialize(this))) } } @@ -224,7 +224,7 @@ private[netty] class NettyRpcEnv( }(ThreadUtils.sameThread) dispatcher.postLocalMessage(message, p) } else { - val rpcMessage = RpcOutboxMessage(serialize(message), + val rpcMessage = RpcOutboxMessage(message.serialize(this), onFailure, (client, response) => onSuccess(deserialize[Any](client, response))) postToOutbox(message.receiver, rpcMessage) @@ -236,7 +236,8 @@ private[netty] class NettyRpcEnv( val timeoutCancelable = timeoutScheduler.schedule(new Runnable { override def run(): Unit = { - onFailure(new TimeoutException(s"Cannot receive any reply in ${timeout.duration}")) + onFailure(new TimeoutException(s"Cannot receive any reply from ${remoteAddr} " + + s"in ${timeout.duration}")) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) promise.future.onComplete { v => @@ -253,6 +254,13 @@ private[netty] class NettyRpcEnv( javaSerializerInstance.serialize(content) } + /** + * Returns [[SerializationStream]] that forwards the serialized bytes to `out`. + */ + private[netty] def serializeStream(out: OutputStream): SerializationStream = { + javaSerializerInstance.serializeStream(out) + } + private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = { NettyRpcEnv.currentClient.withValue(client) { deserialize { () => @@ -407,11 +415,9 @@ private[netty] class NettyRpcEnv( } } - } private[netty] object NettyRpcEnv extends Logging { - /** * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]]. * Use `currentEnv` to wrap the deserialization codes. E.g., @@ -482,16 +488,13 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { */ private[netty] class NettyRpcEndpointRef( @transient private val conf: SparkConf, - endpointAddress: RpcEndpointAddress, - @transient @volatile private var nettyEnv: NettyRpcEnv) - extends RpcEndpointRef(conf) with Serializable with Logging { + private val endpointAddress: RpcEndpointAddress, + @transient @volatile private var nettyEnv: NettyRpcEnv) extends RpcEndpointRef(conf) { @transient @volatile var client: TransportClient = _ - private val _address = if (endpointAddress.rpcAddress != null) endpointAddress else null - private val _name = endpointAddress.name - - override def address: RpcAddress = if (_address != null) _address.rpcAddress else null + override def address: RpcAddress = + if (endpointAddress.rpcAddress != null) endpointAddress.rpcAddress else null private def readObject(in: ObjectInputStream): Unit = { in.defaultReadObject() @@ -503,34 +506,103 @@ private[netty] class NettyRpcEndpointRef( out.defaultWriteObject() } - override def name: String = _name + override def name: String = endpointAddress.name override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - nettyEnv.ask(RequestMessage(nettyEnv.address, this, message), timeout) + nettyEnv.ask(new RequestMessage(nettyEnv.address, this, message), timeout) } override def send(message: Any): Unit = { require(message != null, "Message is null") - nettyEnv.send(RequestMessage(nettyEnv.address, this, message)) + nettyEnv.send(new RequestMessage(nettyEnv.address, this, message)) } - override def toString: String = s"NettyRpcEndpointRef(${_address})" - - def toURI: URI = new URI(_address.toString) + override def toString: String = s"NettyRpcEndpointRef(${endpointAddress})" final override def equals(that: Any): Boolean = that 
match { - case other: NettyRpcEndpointRef => _address == other._address + case other: NettyRpcEndpointRef => endpointAddress == other.endpointAddress case _ => false } - final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() + final override def hashCode(): Int = + if (endpointAddress == null) 0 else endpointAddress.hashCode() } /** * The message that is sent from the sender to the receiver. + * + * @param senderAddress the sender address. It's `null` if this message is from a client + * `NettyRpcEnv`. + * @param receiver the receiver of this message. + * @param content the message content. */ -private[netty] case class RequestMessage( - senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any) +private[netty] class RequestMessage( + val senderAddress: RpcAddress, + val receiver: NettyRpcEndpointRef, + val content: Any) { + + /** Manually serialize [[RequestMessage]] to minimize the size. */ + def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = { + val bos = new ByteBufferOutputStream() + val out = new DataOutputStream(bos) + try { + writeRpcAddress(out, senderAddress) + writeRpcAddress(out, receiver.address) + out.writeUTF(receiver.name) + val s = nettyEnv.serializeStream(out) + try { + s.writeObject(content) + } finally { + s.close() + } + } finally { + out.close() + } + bos.toByteBuffer + } + + private def writeRpcAddress(out: DataOutputStream, rpcAddress: RpcAddress): Unit = { + if (rpcAddress == null) { + out.writeBoolean(false) + } else { + out.writeBoolean(true) + out.writeUTF(rpcAddress.host) + out.writeInt(rpcAddress.port) + } + } + + override def toString: String = s"RequestMessage($senderAddress, $receiver, $content)" +} + +private[netty] object RequestMessage { + + private def readRpcAddress(in: DataInputStream): RpcAddress = { + val hasRpcAddress = in.readBoolean() + if (hasRpcAddress) { + RpcAddress(in.readUTF(), in.readInt()) + } else { + null + } + } + + def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = { + val bis = new ByteBufferInputStream(bytes) + val in = new DataInputStream(bis) + try { + val senderAddress = readRpcAddress(in) + val endpointAddress = RpcEndpointAddress(readRpcAddress(in), in.readUTF()) + val ref = new NettyRpcEndpointRef(nettyEnv.conf, endpointAddress, nettyEnv) + ref.client = client + new RequestMessage( + senderAddress, + ref, + // The remaining bytes in `bytes` are the message content. + nettyEnv.deserialize(client, bytes)) + } finally { + in.close() + } + } +} /** * A response that indicates some failure happens in the receiver side. @@ -576,10 +648,10 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) - val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) + val requestMessage = RequestMessage(nettyEnv, client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. 
- RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) } else { // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // the listening address diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala index 6c090ada5ae9d..a7b7f58376f6b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala @@ -56,7 +56,7 @@ private[netty] case class RpcOutboxMessage( content: ByteBuffer, _onFailure: (Throwable) => Unit, _onSuccess: (TransportClient, ByteBuffer) => Unit) - extends OutboxMessage with RpcResponseCallback { + extends OutboxMessage with RpcResponseCallback with Logging { private var client: TransportClient = _ private var requestId: Long = _ @@ -67,8 +67,11 @@ private[netty] case class RpcOutboxMessage( } def onTimeout(): Unit = { - require(client != null, "TransportClient has not yet been set.") - client.removeRpcRequest(requestId) + if (client != null) { + client.removeRpcRequest(requestId) + } else { + logError("Ask timeout before connecting successfully") + } } override def onFailure(e: Throwable): Unit = { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala index 99f20da2d66aa..430dcc50ba711 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointVerifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** - * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an [[RpcEndpoint]] exists. + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if an `RpcEndpoint` exists. * * This is used when setting up a remote endpoint reference. */ @@ -35,6 +35,6 @@ private[netty] class RpcEndpointVerifier(override val rpcEnv: RpcEnv, dispatcher private[netty] object RpcEndpointVerifier { val NAME = "endpoint-verifier" - /** A message used to ask the remote [[RpcEndpointVerifier]] if an [[RpcEndpoint]] exists. */ + /** A message used to ask the remote [[RpcEndpointVerifier]] if an `RpcEndpoint` exists. */ case class CheckExistence(name: String) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala index cedacad44afec..0a5fe5a1d3ee1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AccumulableInfo.scala @@ -24,11 +24,6 @@ import org.apache.spark.annotation.DeveloperApi * :: DeveloperApi :: * Information about an [[org.apache.spark.Accumulable]] modified during a task or stage. * - * Note: once this is JSON serialized the types of `update` and `value` will be lost and be - * cast to strings. This is because the user can define an accumulator of any type and it will - * be difficult to preserve the type in consumers of the event log. This does not apply to - * internal accumulators that represent task level metrics. 
- * * @param id accumulator ID * @param name accumulator name * @param update partial value from a task, may be None if used on driver to describe a stage @@ -36,6 +31,11 @@ import org.apache.spark.annotation.DeveloperApi * @param internal whether this accumulator was internal * @param countFailedValues whether to count this accumulator's partial value if the task failed * @param metadata internal metadata associated with this accumulator, if any + * + * @note Once this is JSON serialized the types of `update` and `value` will be lost and be + * cast to strings. This is because the user can define an accumulator of any type and it will + * be difficult to preserve the type in consumers of the event log. This does not apply to + * internal accumulators that represent task level metrics. */ @DeveloperApi case class AccumulableInfo private[spark] ( diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 28c45d800ed06..6da8865cd10d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -34,6 +34,7 @@ private[spark] class ApplicationEventListener extends SparkListener { var adminAcls: Option[String] = None var viewAclsGroups: Option[String] = None var adminAclsGroups: Option[String] = None + var appSparkVersion: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { appName = Some(applicationStart.appName) @@ -57,4 +58,10 @@ private[spark] class ApplicationEventListener extends SparkListener { adminAclsGroups = allProperties.get("spark.admin.acls.groups") } } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + appSparkVersion = Some(sparkVersion) + case _ => + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala index fca4c6d37e446..e130e609e4f63 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/BlacklistTracker.scala @@ -17,10 +17,311 @@ package org.apache.spark.scheduler -import org.apache.spark.SparkConf +import java.util.concurrent.atomic.AtomicReference + +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.config -import org.apache.spark.util.Utils +import org.apache.spark.util.{Clock, SystemClock, Utils} + +/** + * BlacklistTracker is designed to track problematic executors and nodes. It supports blacklisting + * executors and nodes across an entire application (with a periodic expiry). TaskSetManagers add + * additional blacklisting of executors and nodes for individual tasks and stages which works in + * concert with the blacklisting here. 
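The `ApplicationEventListener` change above picks the Spark version out of the replayed log-start event via `onOtherEvent`, matching only the event type it cares about and ignoring the rest. A minimal sketch of that listener pattern follows; `DeploymentModeSet` is a made-up event type, used only on the assumption that custom `SparkListenerEvent` subclasses can be posted to the bus.

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

// Hypothetical custom event, for illustration only.
case class DeploymentModeSet(mode: String) extends SparkListenerEvent

class DeploymentModeListener extends SparkListener {
  @volatile var deploymentMode: Option[String] = None

  // Match the events this listener understands and fall through for everything else,
  // in the same shape as the appSparkVersion capture above.
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case DeploymentModeSet(mode) => deploymentMode = Some(mode)
    case _ =>
  }
}
```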
+ * + * The tracker needs to deal with a variety of workloads, eg.: + * + * * bad user code -- this may lead to many task failures, but that should not count against + * individual executors + * * many small stages -- this may prevent a bad executor for having many failures within one + * stage, but still many failures over the entire application + * * "flaky" executors -- they don't fail every task, but are still faulty enough to merit + * blacklisting + * + * See the design doc on SPARK-8425 for a more in-depth discussion. + * + * THREADING: As with most helpers of TaskSchedulerImpl, this is not thread-safe. Though it is + * called by multiple threads, callers must already have a lock on the TaskSchedulerImpl. The + * one exception is [[nodeBlacklist()]], which can be called without holding a lock. + */ +private[scheduler] class BlacklistTracker ( + private val listenerBus: LiveListenerBus, + conf: SparkConf, + allocationClient: Option[ExecutorAllocationClient], + clock: Clock = new SystemClock()) extends Logging { + + def this(sc: SparkContext, allocationClient: Option[ExecutorAllocationClient]) = { + this(sc.listenerBus, sc.conf, allocationClient) + } + + BlacklistTracker.validateBlacklistConfs(conf) + private val MAX_FAILURES_PER_EXEC = conf.get(config.MAX_FAILURES_PER_EXEC) + private val MAX_FAILED_EXEC_PER_NODE = conf.get(config.MAX_FAILED_EXEC_PER_NODE) + val BLACKLIST_TIMEOUT_MILLIS = BlacklistTracker.getBlacklistTimeout(conf) + + /** + * A map from executorId to information on task failures. Tracks the time of each task failure, + * so that we can avoid blacklisting executors due to failures that are very far apart. We do not + * actively remove from this as soon as tasks hit their timeouts, to avoid the time it would take + * to do so. But it will not grow too large, because as soon as an executor gets too many + * failures, we blacklist the executor and remove its entry here. + */ + private val executorIdToFailureList = new HashMap[String, ExecutorFailureList]() + val executorIdToBlacklistStatus = new HashMap[String, BlacklistedExecutor]() + val nodeIdToBlacklistExpiryTime = new HashMap[String, Long]() + /** + * An immutable copy of the set of nodes that are currently blacklisted. Kept in an + * AtomicReference to make [[nodeBlacklist()]] thread-safe. + */ + private val _nodeBlacklist = new AtomicReference[Set[String]](Set()) + /** + * Time when the next blacklist will expire. Used as a + * shortcut to avoid iterating over all entries in the blacklist when none will have expired. + */ + var nextExpiryTime: Long = Long.MaxValue + /** + * Mapping from nodes to all of the executors that have been blacklisted on that node. We do *not* + * remove from this when executors are removed from spark, so we can track when we get multiple + * successive blacklisted executors on one node. Nonetheless, it will not grow too large because + * there cannot be many blacklisted executors on one node, before we stop requesting more + * executors on that node, and we clean up the list of blacklisted executors once an executor has + * been blacklisted for BLACKLIST_TIMEOUT_MILLIS. 
+ */ + val nodeToBlacklistedExecs = new HashMap[String, HashSet[String]]() + + /** + * Un-blacklists executors and nodes that have been blacklisted for at least + * BLACKLIST_TIMEOUT_MILLIS + */ + def applyBlacklistTimeout(): Unit = { + val now = clock.getTimeMillis() + // quickly check if we've got anything to expire from blacklist -- if not, avoid doing any work + if (now > nextExpiryTime) { + // Apply the timeout to blacklisted nodes and executors + val execsToUnblacklist = executorIdToBlacklistStatus.filter(_._2.expiryTime < now).keys + if (execsToUnblacklist.nonEmpty) { + // Un-blacklist any executors that have been blacklisted longer than the blacklist timeout. + logInfo(s"Removing executors $execsToUnblacklist from blacklist because the blacklist " + + s"for those executors has timed out") + execsToUnblacklist.foreach { exec => + val status = executorIdToBlacklistStatus.remove(exec).get + val failedExecsOnNode = nodeToBlacklistedExecs(status.node) + listenerBus.post(SparkListenerExecutorUnblacklisted(now, exec)) + failedExecsOnNode.remove(exec) + if (failedExecsOnNode.isEmpty) { + nodeToBlacklistedExecs.remove(status.node) + } + } + } + val nodesToUnblacklist = nodeIdToBlacklistExpiryTime.filter(_._2 < now).keys + if (nodesToUnblacklist.nonEmpty) { + // Un-blacklist any nodes that have been blacklisted longer than the blacklist timeout. + logInfo(s"Removing nodes $nodesToUnblacklist from blacklist because the blacklist " + + s"has timed out") + nodesToUnblacklist.foreach { node => + nodeIdToBlacklistExpiryTime.remove(node) + listenerBus.post(SparkListenerNodeUnblacklisted(now, node)) + } + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + } + updateNextExpiryTime() + } + } + + private def updateNextExpiryTime(): Unit = { + val execMinExpiry = if (executorIdToBlacklistStatus.nonEmpty) { + executorIdToBlacklistStatus.map{_._2.expiryTime}.min + } else { + Long.MaxValue + } + val nodeMinExpiry = if (nodeIdToBlacklistExpiryTime.nonEmpty) { + nodeIdToBlacklistExpiryTime.values.min + } else { + Long.MaxValue + } + nextExpiryTime = math.min(execMinExpiry, nodeMinExpiry) + } + + + def updateBlacklistForSuccessfulTaskSet( + stageId: Int, + stageAttemptId: Int, + failuresByExec: HashMap[String, ExecutorFailuresInTaskSet]): Unit = { + // if any tasks failed, we count them towards the overall failure count for the executor at + // this point. + val now = clock.getTimeMillis() + failuresByExec.foreach { case (exec, failuresInTaskSet) => + val appFailuresOnExecutor = + executorIdToFailureList.getOrElseUpdate(exec, new ExecutorFailureList) + appFailuresOnExecutor.addFailures(stageId, stageAttemptId, failuresInTaskSet) + appFailuresOnExecutor.dropFailuresWithTimeoutBefore(now) + val newTotal = appFailuresOnExecutor.numUniqueTaskFailures + + val expiryTimeForNewBlacklists = now + BLACKLIST_TIMEOUT_MILLIS + // If this pushes the total number of failures over the threshold, blacklist the executor. + // If its already blacklisted, we avoid "re-blacklisting" (which can happen if there were + // other tasks already running in another taskset when it got blacklisted), because it makes + // some of the logic around expiry times a little more confusing. But it also wouldn't be a + // problem to re-blacklist, with a later expiry time. 
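The `applyBlacklistTimeout`/`updateNextExpiryTime` pair above leans on the cached `nextExpiryTime` so the periodic check returns immediately while nothing can have expired. A stripped-down sketch of that bookkeeping (not Spark's `BlacklistTracker`; listener events, nodes, and locking omitted):

```scala
import scala.collection.mutable

// Entries carry an absolute expiry time; a cached minimum lets applyTimeout be a cheap no-op
// until at least one entry can actually have timed out.
final class ExpiringBlacklist(timeoutMillis: Long) {
  private val expiryByExecutor = mutable.HashMap[String, Long]()
  private var nextExpiryTime: Long = Long.MaxValue

  def blacklist(executorId: String, now: Long): Unit = {
    val expiry = now + timeoutMillis
    expiryByExecutor(executorId) = expiry
    nextExpiryTime = math.min(nextExpiryTime, expiry)
  }

  /** Removes and returns every executor whose timeout has passed. */
  def applyTimeout(now: Long): Seq[String] = {
    if (now <= nextExpiryTime) return Seq.empty          // fast path: nothing has expired yet
    val expired = expiryByExecutor.filter(_._2 < now).keys.toSeq
    expired.foreach(expiryByExecutor.remove)
    nextExpiryTime =
      if (expiryByExecutor.isEmpty) Long.MaxValue else expiryByExecutor.values.min
    expired
  }
}

val bl = new ExpiringBlacklist(timeoutMillis = 60000L)
bl.blacklist("exec-1", now = 0L)
assert(bl.applyTimeout(now = 30000L).isEmpty)            // still blacklisted, fast path taken
assert(bl.applyTimeout(now = 60001L) == Seq("exec-1"))   // expired and removed
```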
+ if (newTotal >= MAX_FAILURES_PER_EXEC && !executorIdToBlacklistStatus.contains(exec)) { + logInfo(s"Blacklisting executor id: $exec because it has $newTotal" + + s" task failures in successful task sets") + val node = failuresInTaskSet.node + executorIdToBlacklistStatus.put(exec, BlacklistedExecutor(node, expiryTimeForNewBlacklists)) + listenerBus.post(SparkListenerExecutorBlacklisted(now, exec, newTotal)) + executorIdToFailureList.remove(exec) + updateNextExpiryTime() + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(allocationClient) => + logInfo(s"Killing blacklisted executor id $exec " + + s"since spark.blacklist.killBlacklistedExecutors is set.") + allocationClient.killExecutors(Seq(exec), true, true) + case None => + logWarning(s"Not attempting to kill blacklisted executor id $exec " + + s"since allocation client is not defined.") + } + } + + // In addition to blacklisting the executor, we also update the data for failures on the + // node, and potentially put the entire node into a blacklist as well. + val blacklistedExecsOnNode = nodeToBlacklistedExecs.getOrElseUpdate(node, HashSet[String]()) + blacklistedExecsOnNode += exec + // If the node is already in the blacklist, we avoid adding it again with a later expiry + // time. + if (blacklistedExecsOnNode.size >= MAX_FAILED_EXEC_PER_NODE && + !nodeIdToBlacklistExpiryTime.contains(node)) { + logInfo(s"Blacklisting node $node because it has ${blacklistedExecsOnNode.size} " + + s"executors blacklisted: ${blacklistedExecsOnNode}") + nodeIdToBlacklistExpiryTime.put(node, expiryTimeForNewBlacklists) + listenerBus.post(SparkListenerNodeBlacklisted(now, node, blacklistedExecsOnNode.size)) + _nodeBlacklist.set(nodeIdToBlacklistExpiryTime.keySet.toSet) + if (conf.get(config.BLACKLIST_KILL_ENABLED)) { + allocationClient match { + case Some(allocationClient) => + logInfo(s"Killing all executors on blacklisted host $node " + + s"since spark.blacklist.killBlacklistedExecutors is set.") + if (allocationClient.killExecutorsOnHost(node) == false) { + logError(s"Killing executors on node $node failed.") + } + case None => + logWarning(s"Not attempting to kill executors on blacklisted host $node " + + s"since allocation client is not defined.") + } + } + } + } + } + } + + def isExecutorBlacklisted(executorId: String): Boolean = { + executorIdToBlacklistStatus.contains(executorId) + } + + /** + * Get the full set of nodes that are blacklisted. Unlike other methods in this class, this *IS* + * thread-safe -- no lock required on a taskScheduler. + */ + def nodeBlacklist(): Set[String] = { + _nodeBlacklist.get() + } + + def isNodeBlacklisted(node: String): Boolean = { + nodeIdToBlacklistExpiryTime.contains(node) + } + + def handleRemovedExecutor(executorId: String): Unit = { + // We intentionally do not clean up executors that are already blacklisted in + // nodeToBlacklistedExecs, so that if another executor on the same node gets blacklisted, we can + // blacklist the entire node. We also can't clean up executorIdToBlacklistStatus, so we can + // eventually remove the executor after the timeout. Despite not clearing those structures + // here, we don't expect they will grow too big since you won't get too many executors on one + // node, and the timeout will clear it up periodically in any case. + executorIdToFailureList -= executorId + } + + + /** + * Tracks all failures for one executor (that have not passed the timeout). 
+ * + * In general we actually expect this to be extremely small, since it won't contain more than the + * maximum number of task failures before an executor is failed (default 2). + */ + private[scheduler] final class ExecutorFailureList extends Logging { + + private case class TaskId(stage: Int, stageAttempt: Int, taskIndex: Int) + + /** + * All failures on this executor in successful task sets. + */ + private var failuresAndExpiryTimes = ArrayBuffer[(TaskId, Long)]() + /** + * As an optimization, we track the min expiry time over all entries in failuresAndExpiryTimes + * so its quick to tell if there are any failures with expiry before the current time. + */ + private var minExpiryTime = Long.MaxValue + + def addFailures( + stage: Int, + stageAttempt: Int, + failuresInTaskSet: ExecutorFailuresInTaskSet): Unit = { + failuresInTaskSet.taskToFailureCountAndFailureTime.foreach { + case (taskIdx, (_, failureTime)) => + val expiryTime = failureTime + BLACKLIST_TIMEOUT_MILLIS + failuresAndExpiryTimes += ((TaskId(stage, stageAttempt, taskIdx), expiryTime)) + if (expiryTime < minExpiryTime) { + minExpiryTime = expiryTime + } + } + } + + /** + * The number of unique tasks that failed on this executor. Only counts failures within the + * timeout, and in successful tasksets. + */ + def numUniqueTaskFailures: Int = failuresAndExpiryTimes.size + + def isEmpty: Boolean = failuresAndExpiryTimes.isEmpty + + /** + * Apply the timeout to individual tasks. This is to prevent one-off failures that are very + * spread out in time (and likely have nothing to do with problems on the executor) from + * triggering blacklisting. However, note that we do *not* remove executors and nodes from + * the blacklist as we expire individual task failures -- each have their own timeout. Eg., + * suppose: + * * timeout = 10, maxFailuresPerExec = 2 + * * Task 1 fails on exec 1 at time 0 + * * Task 2 fails on exec 1 at time 5 + * --> exec 1 is blacklisted from time 5 - 15. + * This is to simplify the implementation, as well as keep the behavior easier to understand + * for the end user. 
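The scenario spelled out in the comment above, reduced to concrete numbers (all values come from that comment; this is arithmetic for illustration, not Spark code):

```scala
// timeout = 10, maxFailuresPerExec = 2, failures on exec 1 at times 0 and 5.
val timeout = 10L
val failureTimes = Seq(0L, 5L)
val failureExpiries = failureTimes.map(_ + timeout)   // Seq(10, 15): when each individual failure ages out
val blacklistedAt = failureTimes.max                  // 5: the second failure reaches the threshold of 2
val blacklistExpiry = blacklistedAt + timeout         // 15: exec 1 stays blacklisted from 5 to 15
```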
+ */ + def dropFailuresWithTimeoutBefore(dropBefore: Long): Unit = { + if (minExpiryTime < dropBefore) { + var newMinExpiry = Long.MaxValue + val newFailures = new ArrayBuffer[(TaskId, Long)] + failuresAndExpiryTimes.foreach { case (task, expiryTime) => + if (expiryTime >= dropBefore) { + newFailures += ((task, expiryTime)) + if (expiryTime < newMinExpiry) { + newMinExpiry = expiryTime + } + } + } + failuresAndExpiryTimes = newFailures + minExpiryTime = newMinExpiry + } + } + + override def toString(): String = { + s"failures = $failuresAndExpiryTimes" + } + } + +} private[scheduler] object BlacklistTracker extends Logging { @@ -80,7 +381,9 @@ private[scheduler] object BlacklistTracker extends Logging { config.MAX_TASK_ATTEMPTS_PER_EXECUTOR, config.MAX_TASK_ATTEMPTS_PER_NODE, config.MAX_FAILURES_PER_EXEC_STAGE, - config.MAX_FAILED_EXEC_PER_NODE_STAGE + config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.MAX_FAILURES_PER_EXEC, + config.MAX_FAILED_EXEC_PER_NODE ).foreach { config => val v = conf.get(config) if (v <= 0) { @@ -112,3 +415,5 @@ private[scheduler] object BlacklistTracker extends Logging { } } } + +private final case class BlacklistedExecutor(node: String, expiryTime: Long) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f2517401cb76b..68178c7fb3bb1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -58,7 +58,7 @@ import org.apache.spark.util._ * set of map output files, and another to read those files after a barrier). In the end, every * stage will have only shuffle dependencies on other stages, and may compute multiple operations * inside it. The actual pipelining of these operations happens in the RDD.compute() functions of - * various RDDs (MappedRDD, FilteredRDD, etc). + * various RDDs * * In addition to coming up with a DAG of stages, the DAGScheduler also determines the preferred * locations to run each task on, based on the current cache status, and passes these to the @@ -187,6 +187,13 @@ class DAGScheduler( /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + /** + * Number of consecutive stage attempts allowed before a stage is aborted. + */ + private[scheduler] val maxConsecutiveStageAttempts = + sc.getConf.getInt("spark.stage.maxConsecutiveAttempts", + DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS) + private val messageScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("dag-scheduler-message") @@ -232,7 +239,7 @@ class DAGScheduler( accumUpdates: Array[(Long, Int, Int, Seq[AccumulableInfo])], blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates)) - blockManagerMaster.driverEndpoint.askWithRetry[Boolean]( + blockManagerMaster.driverEndpoint.askSync[Boolean]( BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) } @@ -600,7 +607,7 @@ class DAGScheduler( * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. 
fair scheduler pool name * - * @throws Exception when the job fails + * @note Throws `Exception` when the job fails */ def runJob[T, U]( rdd: RDD[T], @@ -637,7 +644,7 @@ class DAGScheduler( * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD - * @param evaluator [[ApproximateEvaluator]] to receive the partial results + * @param evaluator `ApproximateEvaluator` to receive the partial results * @param callSite where in the user program this job was called * @param timeout maximum time to wait for the job, in milliseconds * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name @@ -696,9 +703,9 @@ class DAGScheduler( /** * Cancel a job that is running or waiting in the queue. */ - def cancelJob(jobId: Int): Unit = { + def cancelJob(jobId: Int, reason: Option[String]): Unit = { logInfo("Asked to cancel job " + jobId) - eventProcessLoop.post(JobCancelled(jobId)) + eventProcessLoop.post(JobCancelled(jobId, reason)) } /** @@ -719,7 +726,7 @@ class DAGScheduler( private[scheduler] def doCancelAllJobs() { // Cancel all running jobs. runningStages.map(_.firstJobId).foreach(handleJobCancellation(_, - reason = "as part of cancellation of all jobs")) + Option("as part of cancellation of all jobs"))) activeJobs.clear() // These should already be empty by this point, jobIdToActiveJob.clear() // but just in case we lost track of some jobs... } @@ -727,8 +734,17 @@ class DAGScheduler( /** * Cancel all jobs associated with a running or scheduled stage. */ - def cancelStage(stageId: Int) { - eventProcessLoop.post(StageCancelled(stageId)) + def cancelStage(stageId: Int, reason: Option[String]) { + eventProcessLoop.post(StageCancelled(stageId, reason)) + } + + /** + * Kill a given task. It will be retried. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + taskScheduler.killTaskAttempt(taskId, interruptThread, reason) } /** @@ -785,7 +801,8 @@ class DAGScheduler( } } val jobIds = activeInGroup.map(_.jobId) - jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId))) + jobIds.foreach(handleJobCancellation(_, + Option("part of cancelled job group %s".format(groupId)))) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) { @@ -931,8 +948,6 @@ class DAGScheduler( /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") - // Get our pending tasks and remember them in our pendingTasks entry - stage.pendingPartitions.clear() // First figure out the indexes of partition ids to compute. 
val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() @@ -1009,13 +1024,16 @@ class DAGScheduler( } val tasks: Seq[Task[_]] = try { + val serializedTaskMetrics = closureSerializer.serialize(stage.latestInfo.taskMetrics).array() stage match { case stage: ShuffleMapStage => + stage.pendingPartitions.clear() partitionsToCompute.map { id => val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) + stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.latestInfo.taskMetrics, properties, Option(jobId), + taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } @@ -1025,7 +1043,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics, + taskBinary, part, locs, id, properties, serializedTaskMetrics, Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) } } @@ -1037,9 +1055,8 @@ class DAGScheduler( } if (tasks.size > 0) { - logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")") - stage.pendingPartitions ++= tasks.map(_.partitionId) - logDebug("New pending partitions: " + stage.pendingPartitions) + logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) @@ -1089,7 +1106,8 @@ class DAGScheduler( // To avoid UI cruft, ignore cases where value wasn't updated if (acc.name.isDefined && !updates.isZero) { stage.latestInfo.accumulables(id) = acc.toInfo(None, Some(acc.value)) - event.taskInfo.accumulables += acc.toInfo(Some(updates.value), Some(acc.value)) + event.taskInfo.setAccumulables( + acc.toInfo(Some(updates.value), Some(acc.value)) +: event.taskInfo.accumulables) } } } catch { @@ -1144,7 +1162,6 @@ class DAGScheduler( val stage = stageIdToStage(task.stageId) event.reason match { case Success => - stage.pendingPartitions -= task.partitionId task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1184,10 +1201,29 @@ class DAGScheduler( val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) + if (stageIdToStage(task.stageId).latestInfo.attemptId == task.stageAttemptId) { + // This task was for the currently running attempt of the stage. Since the task + // completed successfully from the perspective of the TaskSetManager, mark it as + // no longer pending (the TaskSetManager may consider the task complete even + // when the output needs to be ignored because the task's epoch is too small below. + // In this case, when pending partitions is empty, there will still be missing + // output locations, which will cause the DAGScheduler to resubmit the stage below.) 
+ shuffleStage.pendingPartitions -= task.partitionId + } if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { + // The epoch of the task is acceptable (i.e., the task was launched after the most + // recent failure we're aware of for the executor), so mark the task's output as + // available. shuffleStage.addOutputLoc(smt.partitionId, status) + // Remove the task's partition from pending partitions. This may have already been + // done above, but will not have been done yet in cases where the task attempt was + // from an earlier attempt of the stage (i.e., not the attempt that's currently + // running). This allows the DAGScheduler to mark the stage as complete when one + // copy of each task has finished successfully, even if the currently active stage + // still has tasks running. + shuffleStage.pendingPartitions -= task.partitionId } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { @@ -1211,7 +1247,7 @@ class DAGScheduler( clearCacheLocs() if (!shuffleStage.isAvailable) { - // Some tasks had failed; let's resubmit this shuffleStage + // Some tasks had failed; let's resubmit this shuffleStage. // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + ") because some of its tasks had failed: " + @@ -1232,7 +1268,14 @@ class DAGScheduler( case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") - stage.pendingPartitions += task.partitionId + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + assert(false, "TaskSetManagers should only send Resubmitted task statuses for " + + "tasks in ShuffleMapStages.") + } case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) @@ -1255,27 +1298,47 @@ class DAGScheduler( s"longer running") } - if (disallowStageRetryForTest) { - abortStage(failedStage, "Fetch failure will not retry stage due to testing config", - None) - } else if (failedStage.failedOnFetchAndShouldAbort(task.stageAttemptId)) { - abortStage(failedStage, s"$failedStage (${failedStage.name}) " + - s"has failed the maximum allowable number of " + - s"times: ${Stage.MAX_CONSECUTIVE_FETCH_FAILURES}. " + - s"Most recent failure reason: ${failureMessage}", None) - } else { - if (failedStages.isEmpty) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. - // TODO: Cancel running tasks in the stage - logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure") - messageScheduler.schedule(new Runnable { - override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) - }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Fetch failure will not retry stage due to testing config" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. 
+ |Most recent failure reason: $failureMessage""".stripMargin.replaceAll("\n", " ") } + abortStage(failedStage, abortMessage, None) + } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued + // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064 + val noResubmitEnqueued = !failedStages.contains(failedStage) failedStages += failedStage failedStages += mapStage + if (noResubmitEnqueued) { + // We expect one executor failure to trigger many FetchFailures in rapid succession, + // but all of those task failures can typically be handled by a single resubmission of + // the failed stage. We avoid flooding the scheduler's event queue with resubmit + // messages by checking whether a resubmit is already in the event queue for the + // failed stage. If there is already a resubmit enqueued for a different failed + // stage, that event would also be sufficient to handle the current failed stage, but + // producing a resubmit for each failed stage makes debugging and logging a little + // simpler while not producing an overwhelming number of scheduler events. + logInfo( + s"Resubmitting $mapStage (${mapStage.name}) and " + + s"$failedStage (${failedStage.name}) due to fetch failure" + ) + messageScheduler.schedule( + new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, + DAGScheduler.RESUBMIT_TIMEOUT, + TimeUnit.MILLISECONDS + ) + } } // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { @@ -1299,7 +1362,7 @@ class DAGScheduler( case TaskResultLost => // Do nothing here; the TaskScheduler handles these failures and resubmits the task. - case _: ExecutorLostFailure | TaskKilled | UnknownReason => + case _: ExecutorLostFailure | _: TaskKilled | UnknownReason => // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler // will abort the job. 
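With the fetch-failure handling above, a stage is aborted once it has accumulated `spark.stage.maxConsecutiveAttempts` attempts with fetch failures; the default is 4 via `DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS`. A minimal sketch of raising that limit for a job on flaky shuffle storage (the app name and chosen value are made up):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("stage-retry-demo")
  // Allow more consecutive stage attempts with fetch failures before the DAGScheduler
  // gives up; this patch's default is 4.
  .set("spark.stage.maxConsecutiveAttempts", "8")
```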
} @@ -1356,24 +1419,30 @@ class DAGScheduler( } } - private[scheduler] def handleStageCancellation(stageId: Int) { + private[scheduler] def handleStageCancellation(stageId: Int, reason: Option[String]) { stageIdToStage.get(stageId) match { case Some(stage) => val jobsThatUseStage: Array[Int] = stage.jobIds.toArray jobsThatUseStage.foreach { jobId => - handleJobCancellation(jobId, s"because Stage $stageId was cancelled") + val reasonStr = reason match { + case Some(originalReason) => + s"because $originalReason" + case None => + s"because Stage $stageId was cancelled" + } + handleJobCancellation(jobId, Option(reasonStr)) } case None => logInfo("No active jobs to kill for Stage " + stageId) } } - private[scheduler] def handleJobCancellation(jobId: Int, reason: String = "") { + private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]) { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { failJobAndIndependentStages( - jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason)) + jobIdToActiveJob(jobId), "Job %d cancelled %s".format(jobId, reason.getOrElse(""))) } } @@ -1615,11 +1684,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case MapStageSubmitted(jobId, dependency, callSite, listener, properties) => dagScheduler.handleMapStageSubmitted(jobId, dependency, callSite, listener, properties) - case StageCancelled(stageId) => - dagScheduler.handleStageCancellation(stageId) + case StageCancelled(stageId, reason) => + dagScheduler.handleStageCancellation(stageId, reason) - case JobCancelled(jobId) => - dagScheduler.handleJobCancellation(jobId) + case JobCancelled(jobId, reason) => + dagScheduler.handleJobCancellation(jobId, reason) case JobGroupCancelled(groupId) => dagScheduler.handleJobGroupCancelled(groupId) @@ -1660,7 +1729,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler } catch { case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t) } - dagScheduler.sc.stop() + dagScheduler.sc.stopInNewThread() } override def onStop(): Unit = { @@ -1674,4 +1743,7 @@ private[spark] object DAGScheduler { // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in val RESUBMIT_TIMEOUT = 200 + + // Number of consecutive stage attempts allowed before a stage is aborted + val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 03781a2a2b56c..cda0585f154a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -53,9 +53,15 @@ private[scheduler] case class MapStageSubmitted( properties: Properties = null) extends DAGSchedulerEvent -private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent +private[scheduler] case class StageCancelled( + stageId: Int, + reason: Option[String]) + extends DAGSchedulerEvent -private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent +private[scheduler] case class JobCancelled( + jobId: Int, + reason: Option[String]) + extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala 
b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index ce7877469f03f..f481436332249 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io._ import java.net.URI import java.nio.charset.StandardCharsets +import java.util.Locale import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -118,7 +119,7 @@ private[spark] class EventLoggingListener( val cstream = compressionCodec.map(_.compressedOutputStream(dstream)).getOrElse(dstream) val bstream = new BufferedOutputStream(cstream, outputBufferSize) - EventLoggingListener.initEventLog(bstream) + EventLoggingListener.initEventLog(bstream, testing, loggedEvents) fileSystem.setPermission(path, LOG_FILE_PERMISSIONS) writer = Some(new PrintWriter(bstream)) logInfo("Logging events to %s".format(logPath)) @@ -153,7 +154,9 @@ private[spark] class EventLoggingListener( override def onTaskEnd(event: SparkListenerTaskEnd): Unit = logEvent(event) - override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = logEvent(event) + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + logEvent(redactEvent(event)) + } // Events that trigger a flush override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { @@ -191,6 +194,22 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } + override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = { + logEvent(event, flushLogger = true) + } + + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = { + logEvent(event, flushLogger = true) + } + + override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = { + logEvent(event, flushLogger = true) + } + // No-op because logging every update would be overkill override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} @@ -231,6 +250,21 @@ private[spark] class EventLoggingListener( } } + private[spark] def redactEvent( + event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. + val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) + } + } private[spark] object EventLoggingListener extends Logging { @@ -249,10 +283,17 @@ private[spark] object EventLoggingListener extends Logging { * * @param logStream Raw output stream to the event log file. 
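`redactEvent` above runs every property list through `Utils.redact` before the environment update reaches the event log. The sketch below shows the general idea with a hard-coded pattern; the regex, property names, and values are assumptions for illustration, and the real pattern comes from Spark configuration rather than being inlined like this.

```scala
import scala.util.matching.Regex

// Mask the value of any property whose key looks sensitive before it is written anywhere.
def redactProps(pattern: Regex, props: Seq[(String, String)]): Seq[(String, String)] =
  props.map { case (key, value) =>
    if (pattern.findFirstIn(key).isDefined) (key, "*********(redacted)") else (key, value)
  }

val sensitive: Regex = "(?i)secret|password".r   // assumed pattern, for illustration only
val sparkProperties = Seq(
  "spark.app.name" -> "event-log-demo",
  "spark.hadoop.fs.s3a.secret.key" -> "not-really-a-key")

redactProps(sensitive, sparkProperties)
// => the app name passes through untouched, the secret key's value is masked
```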
*/ - def initEventLog(logStream: OutputStream): Unit = { + def initEventLog( + logStream: OutputStream, + testing: Boolean, + loggedEvents: ArrayBuffer[JValue]): Unit = { val metadata = SparkListenerLogStart(SPARK_VERSION) - val metadataJson = compact(JsonProtocol.logStartToJson(metadata)) + "\n" + val eventJson = JsonProtocol.logStartToJson(metadata) + val metadataJson = compact(eventJson) + "\n" logStream.write(metadataJson.getBytes(StandardCharsets.UTF_8)) + if (testing && loggedEvents != null) { + loggedEvents += eventJson + } } /** @@ -289,7 +330,7 @@ private[spark] object EventLoggingListener extends Logging { } private def sanitize(str: String): String = { - str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase + str.replaceAll("[ :/]", "-").replaceAll("[.${}'\"]", "_").toLowerCase(Locale.ROOT) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala index 20ab27d127aba..70553d8be28b5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorFailuresInTaskSet.scala @@ -25,26 +25,30 @@ import scala.collection.mutable.HashMap private[scheduler] class ExecutorFailuresInTaskSet(val node: String) { /** * Mapping from index of the tasks in the taskset, to the number of times it has failed on this - * executor. + * executor and the most recent failure time. */ - val taskToFailureCount = HashMap[Int, Int]() + val taskToFailureCountAndFailureTime = HashMap[Int, (Int, Long)]() - def updateWithFailure(taskIndex: Int): Unit = { - val prevFailureCount = taskToFailureCount.getOrElse(taskIndex, 0) - taskToFailureCount(taskIndex) = prevFailureCount + 1 + def updateWithFailure(taskIndex: Int, failureTime: Long): Unit = { + val (prevFailureCount, prevFailureTime) = + taskToFailureCountAndFailureTime.getOrElse(taskIndex, (0, -1L)) + // these times always come from the driver, so we don't need to worry about skew, but might + // as well still be defensive in case there is non-monotonicity in the clock + val newFailureTime = math.max(prevFailureTime, failureTime) + taskToFailureCountAndFailureTime(taskIndex) = (prevFailureCount + 1, newFailureTime) } - def numUniqueTasksWithFailures: Int = taskToFailureCount.size + def numUniqueTasksWithFailures: Int = taskToFailureCountAndFailureTime.size /** * Return the number of times this executor has failed on the given task index. */ def getNumTaskFailures(index: Int): Int = { - taskToFailureCount.getOrElse(index, 0) + taskToFailureCountAndFailureTime.getOrElse(index, (0, 0))._1 } override def toString(): String = { s"numUniqueTasksWithFailures = $numUniqueTasksWithFailures; " + - s"tasksToFailureCount = $taskToFailureCount" + s"tasksToFailureCount = $taskToFailureCountAndFailureTime" } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala index d1ac7131baba5..47f3527a32c01 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -42,7 +42,7 @@ private[spark] trait ExternalClusterManager { /** * Create a scheduler backend for the given SparkContext and scheduler. This is - * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. 
+ * called after task scheduler is created using `ExternalClusterManager.createTaskScheduler()`. * @param sc SparkContext * @param masterURL the master URL * @param scheduler TaskScheduler that will be used with the scheduler backend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index a6b032cc0084c..66ab9a52b7781 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -153,7 +153,7 @@ object InputFormatInfo { a) For each host, count number of splits hosted on that host. b) Decrement the currently allocated containers on that host. - c) Compute rack info for each host and update rack -> count map based on (b). + c) Compute rack info for each host and update rack to count map based on (b). d) Allocate nodes based on (c) e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node (even if data locality on that is very high) : this is to prevent fragility of job if a diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 9012289f047c5..65d7184231e24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -50,7 +50,7 @@ private[spark] class JobWaiter[T]( * will fail this job with a SparkException. */ def cancel() { - dagScheduler.cancelJob(jobId) + dagScheduler.cancelJob(jobId, None) } override def taskSucceeded(index: Int, result: Any): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 7bed6851d0cde..83d87b548a430 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.util.{RpcUtils, ThreadUtils} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -47,25 +48,29 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private type StageId = Int private type PartitionId = Int private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + private case class StageState(numPartitions: Int) { + val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) + val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + } /** - * Map from active stages's id => partition id => task attempt with exclusive lock on committing - * output for that partition. + * Map from active stages's id => authorized task attempts for each partition id, which hold an + * exclusive lock on committing task output for that partition, as well as any known failed + * attempts in the stage. * * Entries are added to the top-level map when stages start and are removed they finish * (either successfully or unsuccessfully). * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. 
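The `StageState` introduced here tracks, per partition, the authorized committer attempt plus the set of attempts known to have failed; the hunks that follow use it so a failed attempt can never commit and a repeated ask from the already-authorized attempt stays authorized. A toy model of those rules (not Spark's coordinator; RPC, logging, and synchronization omitted):

```scala
import scala.collection.mutable

final class CommitStateSketch(numPartitions: Int) {
  private val NoCommitter = -1
  private val authorized = Array.fill(numPartitions)(NoCommitter)
  private val failed = mutable.Map[Int, mutable.Set[Int]]()

  // Record a task attempt that failed for this partition.
  def attemptFailed(partition: Int, attempt: Int): Unit =
    failed.getOrElseUpdate(partition, mutable.Set()) += attempt

  // First asker wins; the same attempt asking again stays authorized; failed attempts are denied.
  def canCommit(partition: Int, attempt: Int): Boolean = {
    if (failed.get(partition).exists(_.contains(attempt))) {
      false
    } else if (authorized(partition) == NoCommitter) {
      authorized(partition) = attempt
      true
    } else {
      authorized(partition) == attempt
    }
  }
}

val state = new CommitStateSketch(numPartitions = 2)
assert(state.canCommit(partition = 0, attempt = 0))    // first ask is authorized
assert(state.canCommit(partition = 0, attempt = 0))    // duplicate ask from the same attempt stays true
assert(!state.canCommit(partition = 0, attempt = 1))   // a different attempt is denied
state.attemptFailed(partition = 1, attempt = 3)
assert(!state.canCommit(partition = 1, attempt = 3))   // an attempt that already failed can never commit
```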
*/ - private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() + private val stageStates = mutable.Map[StageId, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. */ def isEmpty: Boolean = { - authorizedCommittersByStage.isEmpty + stageStates.isEmpty } /** @@ -88,7 +93,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber) coordinatorRef match { case Some(endpointRef) => - endpointRef.askWithRetry[Boolean](msg) + ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg), + RpcUtils.askRpcTimeout(conf).duration) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") @@ -103,19 +109,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart( - stage: StageId, - maxPartitionId: Int): Unit = { - val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) - java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) - synchronized { - authorizedCommittersByStage(stage) = arr - } + private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { + stageStates(stage) = new StageState(maxPartitionId + 1) } // Called by DAGScheduler private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { - authorizedCommittersByStage.remove(stage) + stageStates.remove(stage) } // Called by DAGScheduler @@ -124,7 +124,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) partition: PartitionId, attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { - val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") return }) @@ -135,10 +135,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters(partition) == attemptNumber) { + // Mark the attempt as failed to blacklist from future commit protocol + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber + if (stageState.authorizedCommitters(partition) == attemptNumber) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -147,7 +149,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) if (isDriver) { coordinatorRef.foreach(_ send StopCoordinator) coordinatorRef = None - authorizedCommittersByStage.clear() + stageStates.clear() } } @@ -156,25 +158,45 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) stage: StageId, partition: PartitionId, attemptNumber: TaskAttemptNumber): Boolean = synchronized { - authorizedCommittersByStage.get(stage) match { - case Some(authorizedCommitters) => - authorizedCommitters(partition) match { + stageStates.get(stage) match { + case Some(state) if attemptFailed(state, partition, attemptNumber) => + 
logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + + s" partition=$partition as task attempt $attemptNumber has already failed.") + false + case Some(state) => + state.authorizedCommitters(partition) match { case NO_AUTHORIZED_COMMITTER => logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + s"partition=$partition") - authorizedCommitters(partition) = attemptNumber + state.authorizedCommitters(partition) = attemptNumber true case existingCommitter => - logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + - s"partition=$partition; existingCommitter = $existingCommitter") - false + // Coordinator should be idempotent when receiving AskPermissionToCommit. + if (existingCommitter == attemptNumber) { + logWarning(s"Authorizing duplicate request to commit for " + + s"attemptNumber=$attemptNumber to commit for stage=$stage," + + s" partition=$partition; existingCommitter = $existingCommitter." + + s" This can indicate dropped network traffic.") + true + } else { + logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " + + s"partition=$partition; existingCommitter = $existingCommitter") + false + } } case None => - logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + - s"partition $partition to commit") + logDebug(s"Stage $stage has completed, so not allowing" + + s" attempt number $attemptNumber of partition $partition to commit") false } } + + private def attemptFailed( + stageState: StageState, + partition: PartitionId, + attempt: TaskAttemptNumber): Boolean = synchronized { + stageState.failures.get(partition).exists(_.contains(attempt)) + } } private[spark] object OutputCommitCoordinator { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 2a69a6c5e8790..1181371ab425a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -37,24 +37,24 @@ private[spark] class Pool( val schedulableQueue = new ConcurrentLinkedQueue[Schedulable] val schedulableNameToSchedulable = new ConcurrentHashMap[String, Schedulable] - var weight = initWeight - var minShare = initMinShare + val weight = initWeight + val minShare = initMinShare var runningTasks = 0 - var priority = 0 + val priority = 0 // A pool's stage id is used to break the tie in scheduling. var stageId = -1 - var name = poolName + val name = poolName var parent: Pool = null - var taskSetSchedulingAlgorithm: SchedulingAlgorithm = { + private val taskSetSchedulingAlgorithm: SchedulingAlgorithm = { schedulingMode match { case SchedulingMode.FAIR => new FairSchedulingAlgorithm() case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() case _ => - val msg = "Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." + val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." 
throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index 0bd5a6bc59a9e..08e05ae0c095b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -22,6 +22,7 @@ import java.io.{InputStream, IOException} import scala.io.Source import com.fasterxml.jackson.core.JsonParseException +import com.fasterxml.jackson.databind.exc.UnrecognizedPropertyException import org.json4s.jackson.JsonMethods._ import org.apache.spark.internal.Logging @@ -87,6 +88,12 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging { // Ignore events generated by Structured Streaming in Spark 2.0.0 and 2.0.1. // It's safe since no place uses them. logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") + case e: UnrecognizedPropertyException if e.getMessage != null && e.getMessage.startsWith( + "Unrecognized field \"queryStatus\" " + + "(class org.apache.spark.sql.streaming.StreamingQueryListener$") => + // Ignore events generated by Structured Streaming in Spark 2.0.2 + // It's safe since no place uses them. + logWarning(s"Dropped incompatible Structured Streaming log: $currentLine") case jpe: JsonParseException => // We can only ignore exception from last line of the file that might be truncated // the last entry may not be the very last line in the event log, but we treat it diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 1e7c63af2e797..e36c759a42556 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -24,7 +24,6 @@ import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD /** @@ -42,7 +41,8 @@ import org.apache.spark.rdd.RDD * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side. 
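The constructor change here (continued below for `ResultTask`, and already applied to `ShuffleMapTask` in the DAGScheduler hunk) hands every task the same pre-serialized metrics bytes instead of a live `TaskMetrics` object, so the serialization cost is paid once per stage. A generic sketch of that serialize-once pattern with plain Java serialization; the `Metrics` class is a placeholder, not Spark's `TaskMetrics`.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// Placeholder for a driver-side object that every task in a stage needs a copy of.
case class Metrics(executorRunTime: Long = 0L, resultSize: Long = 0L) extends Serializable

// Serialize once per stage...
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(Metrics())
oos.close()
val serializedTaskMetrics: Array[Byte] = bos.toByteArray

// ...then share the same byte array with every task, instead of serializing the object
// again as part of each task's own serialized form (4 tasks in this toy stage).
val perTaskPayloads: Seq[Array[Byte]] = Seq.fill(4)(serializedTaskMetrics)
```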
* * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -57,12 +57,12 @@ private[spark] class ResultTask[T, U]( locs: Seq[TaskLocation], val outputId: Int, localProperties: Properties, - metrics: TaskMetrics, + serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, appAttemptId: Option[String] = None) - extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, - appId, appAttemptId) + extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, + jobId, appId, appAttemptId) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 96325a0329f89..5f3c280ec31ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -18,12 +18,14 @@ package org.apache.spark.scheduler import java.io.{FileInputStream, InputStream} -import java.util.{NoSuchElementException, Properties} +import java.util.{Locale, NoSuchElementException, Properties} -import scala.xml.XML +import scala.util.control.NonFatal +import scala.xml.{Node, XML} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils /** @@ -54,7 +56,8 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) extends SchedulableBuilder with Logging { - val schedulerAllocFile = conf.getOption("spark.scheduler.allocation.file") + val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file" + val schedulerAllocFile = conf.getOption(SCHEDULER_ALLOCATION_FILE_PROPERTY) val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml" val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool" val DEFAULT_POOL_NAME = "default" @@ -68,19 +71,35 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) val DEFAULT_WEIGHT = 1 override def buildPools() { - var is: Option[InputStream] = None + var fileData: Option[(InputStream, String)] = None try { - is = Option { - schedulerAllocFile.map { f => - new FileInputStream(f) - }.getOrElse { - Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE) + fileData = schedulerAllocFile.map { f => + val fis = new FileInputStream(f) + logInfo(s"Creating Fair Scheduler pools from $f") + Some((fis, f)) + }.getOrElse { + val is = Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE) + if (is != null) { + logInfo(s"Creating Fair Scheduler pools from default file: $DEFAULT_SCHEDULER_FILE") + Some((is, DEFAULT_SCHEDULER_FILE)) + } else { + logWarning("Fair Scheduler configuration file not found so jobs will be scheduled in " + + s"FIFO order. 
To use fair scheduling, configure pools in $DEFAULT_SCHEDULER_FILE or " + + s"set $SCHEDULER_ALLOCATION_FILE_PROPERTY to a file that contains the configuration.") + None } } - is.foreach { i => buildFairSchedulerPool(i) } + fileData.foreach { case (is, fileName) => buildFairSchedulerPool(is, fileName) } + } catch { + case NonFatal(t) => + val defaultMessage = "Error while building the fair scheduler pools" + val message = fileData.map { case (is, fileName) => s"$defaultMessage from $fileName" } + .getOrElse(defaultMessage) + logError(message, t) + throw t } finally { - is.foreach(_.close()) + fileData.foreach { case (is, fileName) => is.close() } } // finally create "default" pool @@ -92,63 +111,93 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) val pool = new Pool(DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) rootPool.addSchedulable(pool) - logInfo("Created default pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( + logInfo("Created default pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( DEFAULT_POOL_NAME, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) } } - private def buildFairSchedulerPool(is: InputStream) { + private def buildFairSchedulerPool(is: InputStream, fileName: String) { val xml = XML.load(is) for (poolNode <- (xml \\ POOLS_PROPERTY)) { val poolName = (poolNode \ POOL_NAME_PROPERTY).text - var schedulingMode = DEFAULT_SCHEDULING_MODE - var minShare = DEFAULT_MINIMUM_SHARE - var weight = DEFAULT_WEIGHT - - val xmlSchedulingMode = (poolNode \ SCHEDULING_MODE_PROPERTY).text - if (xmlSchedulingMode != "") { - try { - schedulingMode = SchedulingMode.withName(xmlSchedulingMode) - } catch { - case e: NoSuchElementException => - logWarning(s"Unsupported schedulingMode: $xmlSchedulingMode, " + - s"using the default schedulingMode: $schedulingMode") - } - } - val xmlMinShare = (poolNode \ MINIMUM_SHARES_PROPERTY).text - if (xmlMinShare != "") { - minShare = xmlMinShare.toInt - } + val schedulingMode = getSchedulingModeValue(poolNode, poolName, + DEFAULT_SCHEDULING_MODE, fileName) + val minShare = getIntValue(poolNode, poolName, MINIMUM_SHARES_PROPERTY, + DEFAULT_MINIMUM_SHARE, fileName) + val weight = getIntValue(poolNode, poolName, WEIGHT_PROPERTY, + DEFAULT_WEIGHT, fileName) - val xmlWeight = (poolNode \ WEIGHT_PROPERTY).text - if (xmlWeight != "") { - weight = xmlWeight.toInt - } + rootPool.addSchedulable(new Pool(poolName, schedulingMode, minShare, weight)) - val pool = new Pool(poolName, schedulingMode, minShare, weight) - rootPool.addSchedulable(pool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( + logInfo("Created pool: %s, schedulingMode: %s, minShare: %d, weight: %d".format( poolName, schedulingMode, minShare, weight)) } } + private def getSchedulingModeValue( + poolNode: Node, + poolName: String, + defaultValue: SchedulingMode, + fileName: String): SchedulingMode = { + + val xmlSchedulingMode = + (poolNode \ SCHEDULING_MODE_PROPERTY).text.trim.toUpperCase(Locale.ROOT) + val warningMessage = s"Unsupported schedulingMode: $xmlSchedulingMode found in " + + s"Fair Scheduler configuration file: $fileName, using " + + s"the default schedulingMode: $defaultValue for pool: $poolName" + try { + if (SchedulingMode.withName(xmlSchedulingMode) != SchedulingMode.NONE) { + SchedulingMode.withName(xmlSchedulingMode) + } else { + logWarning(warningMessage) + defaultValue + } + } catch { + case e: NoSuchElementException => + 
logWarning(warningMessage) + defaultValue + } + } + + private def getIntValue( + poolNode: Node, + poolName: String, + propertyName: String, + defaultValue: Int, + fileName: String): Int = { + + val data = (poolNode \ propertyName).text.trim + try { + data.toInt + } catch { + case e: NumberFormatException => + logWarning(s"Error while loading fair scheduler configuration from $fileName: " + + s"$propertyName is blank or invalid: $data, using the default $propertyName: " + + s"$defaultValue for pool: $poolName") + defaultValue + } + } + override def addTaskSetManager(manager: Schedulable, properties: Properties) { - var poolName = DEFAULT_POOL_NAME - var parentPool = rootPool.getSchedulableByName(poolName) - if (properties != null) { - poolName = properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) - parentPool = rootPool.getSchedulableByName(poolName) - if (parentPool == null) { - // we will create a new pool that user has configured in app - // instead of being defined in xml file - parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, - DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) - rootPool.addSchedulable(parentPool) - logInfo("Created pool %s, schedulingMode: %s, minShare: %d, weight: %d".format( - poolName, DEFAULT_SCHEDULING_MODE, DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT)) + val poolName = if (properties != null) { + properties.getProperty(FAIR_SCHEDULER_PROPERTIES, DEFAULT_POOL_NAME) + } else { + DEFAULT_POOL_NAME } + var parentPool = rootPool.getSchedulableByName(poolName) + if (parentPool == null) { + // we will create a new pool that user has configured in app + // instead of being defined in xml file + parentPool = new Pool(poolName, DEFAULT_SCHEDULING_MODE, + DEFAULT_MINIMUM_SHARE, DEFAULT_WEIGHT) + rootPool.addSchedulable(parentPool) + logWarning(s"A job was submitted with scheduler pool $poolName, which has not been " + + "configured. This can happen when the file that pools are read from isn't set, or " + + s"when that file doesn't contain $poolName. Created $poolName with default " + + s"configuration (schedulingMode: $DEFAULT_SCHEDULING_MODE, " + + s"minShare: $DEFAULT_MINIMUM_SHARE, weight: $DEFAULT_WEIGHT)") } parentPool.addSchedulable(manager) logInfo("Added task set " + manager.name + " tasks to pool " + poolName) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 8801a761afae3..22db3350abfa7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,8 +30,21 @@ private[spark] trait SchedulerBackend { def reviveOffers(): Unit def defaultParallelism(): Int - def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = + /** + * Requests that an executor kills a running task. + * + * @param taskId Id of the task. + * @param executorId Id of the executor the task is running on. + * @param interruptThread Whether the executor should interrupt the task thread. + * @param reason The reason for the task kill. 
+ */ + def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = throw new UnsupportedOperationException + def isReady(): Boolean = true /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 51416e5ce97fc..db4d9efa2270c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashSet + import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId @@ -47,6 +49,17 @@ private[spark] class ShuffleMapStage( private[this] var _numAvailableOutputs: Int = 0 + /** + * Partitions that either haven't yet been computed, or that were computed on an executor + * that has since been lost, so should be re-computed. This variable is used by the + * DAGScheduler to determine when a stage has completed. Task successes in either the active + * attempt for the stage or in earlier attempts for this stage can cause partition ids to get + * removed from pendingPartitions. As a result, this variable may be inconsistent with the pending + * tasks in the TaskSetManager for the active attempt for the stage (the partitions stored here + * will always be a subset of the partitions that the TaskSetManager thinks are pending). + */ + val pendingPartitions = new HashSet[Int] + /** * List of [[MapStatus]] for each partition. The index of the array is the map partition id, * and each value in the array is the list of possible [[MapStatus]] for a partition diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 66d6790e168f2..7a25c47e2cab3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -25,7 +25,6 @@ import scala.language.existentials import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter @@ -42,8 +41,9 @@ import org.apache.spark.shuffle.ShuffleWriter * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side.
* * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -56,18 +56,18 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - metrics: TaskMetrics, localProperties: Properties, + serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, appAttemptId: Option[String] = None) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties, jobId, - appId, appAttemptId) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, + serializedTaskMetrics, jobId, appId, appAttemptId) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, new Properties, null) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 7618dfeeedf8d..59f89a82a1da8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -87,8 +87,13 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) - extends SparkListenerEvent +case class SparkListenerBlockManagerAdded( + time: Long, + blockManagerId: BlockManagerId, + maxMem: Long, + maxOnHeapMem: Option[Long] = None, + maxOffHeapMem: Option[Long] = None) extends SparkListenerEvent { +} @DeveloperApi case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) @@ -105,6 +110,28 @@ case class SparkListenerExecutorAdded(time: Long, executorId: String, executorIn case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerExecutorBlacklisted( + time: Long, + executorId: String, + taskFailures: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerExecutorUnblacklisted(time: Long, executorId: String) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeBlacklisted( + time: Long, + hostId: String, + executorFailures: Int) + extends SparkListenerEvent + +@DeveloperApi +case class SparkListenerNodeUnblacklisted(time: Long, hostId: String) + extends SparkListenerEvent + @DeveloperApi case class SparkListenerBlockUpdated(blockUpdatedInfo: BlockUpdatedInfo) extends SparkListenerEvent @@ -133,9 +160,9 @@ case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent /** * An internal class that describes the metadata of an event log. - * This event is not meant to be posted to listeners downstream. 
*/ -private[spark] case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent +@DeveloperApi +case class SparkListenerLogStart(sparkVersion: String) extends SparkListenerEvent /** * Interface for creating history listeners defined in other modules like SQL, which are used to @@ -238,6 +265,26 @@ private[spark] trait SparkListenerInterface { */ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit + /** + * Called when the driver blacklists an executor for a Spark application. + */ + def onExecutorBlacklisted(executorBlacklisted: SparkListenerExecutorBlacklisted): Unit + + /** + * Called when the driver re-enables a previously blacklisted executor. + */ + def onExecutorUnblacklisted(executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit + + /** + * Called when the driver blacklists a node for a Spark application. + */ + def onNodeBlacklisted(nodeBlacklisted: SparkListenerNodeBlacklisted): Unit + + /** + * Called when the driver re-enables a previously blacklisted node. + */ + def onNodeUnblacklisted(nodeUnblacklisted: SparkListenerNodeUnblacklisted): Unit + /** * Called when the driver receives a block update info. */ @@ -252,7 +299,7 @@ private[spark] trait SparkListenerInterface { /** * :: DeveloperApi :: - * A default implementation for [[SparkListenerInterface]] that has no-op implementations for + * A default implementation for `SparkListenerInterface` that has no-op implementations for * all callbacks. * * Note that this is an internal interface which might change in different Spark releases. @@ -293,6 +340,18 @@ abstract class SparkListener extends SparkListenerInterface { override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { } + override def onExecutorBlacklisted( + executorBlacklisted: SparkListenerExecutorBlacklisted): Unit = { } + + override def onExecutorUnblacklisted( + executorUnblacklisted: SparkListenerExecutorUnblacklisted): Unit = { } + + override def onNodeBlacklisted( + nodeBlacklisted: SparkListenerNodeBlacklisted): Unit = { } + + override def onNodeUnblacklisted( + nodeUnblacklisted: SparkListenerNodeUnblacklisted): Unit = { } + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { } override def onOtherEvent(event: SparkListenerEvent): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 471586ac0852a..3b0d3b1b150fe 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,9 +61,16 @@ private[spark] trait SparkListenerBus listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => listener.onExecutorRemoved(executorRemoved) + case executorBlacklisted: SparkListenerExecutorBlacklisted => + listener.onExecutorBlacklisted(executorBlacklisted) + case executorUnblacklisted: SparkListenerExecutorUnblacklisted => + listener.onExecutorUnblacklisted(executorUnblacklisted) + case nodeBlacklisted: SparkListenerNodeBlacklisted => + listener.onNodeBlacklisted(nodeBlacklisted) + case nodeUnblacklisted: SparkListenerNodeUnblacklisted => + listener.onNodeUnblacklisted(nodeUnblacklisted) case blockUpdated: SparkListenerBlockUpdated => listener.onBlockUpdated(blockUpdated) - case logStart: SparkListenerLogStart => // ignore event log metadata case _ => listener.onOtherEvent(event) } } diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 2f972b064b477..290fd073caf27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet -import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -68,8 +67,6 @@ private[scheduler] abstract class Stage( /** Set of jobs that this stage belongs to. */ val jobIds = new HashSet[Int] - val pendingPartitions = new HashSet[Int] - /** The ID to use for the next new attempt for this stage. */ private var nextAttemptId: Int = 0 @@ -77,7 +74,7 @@ private[scheduler] abstract class Stage( val details: String = callSite.longForm /** - * Pointer to the [StageInfo] object for the most recent attempt. This needs to be initialized + * Pointer to the [[StageInfo]] object for the most recent attempt. This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this * StageInfo to tell SparkListeners when a job starts (which happens before any stage attempts * have been created). @@ -90,23 +87,12 @@ private[scheduler] abstract class Stage( * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - private val fetchFailedAttemptIds = new HashSet[Int] + val fetchFailedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { fetchFailedAttemptIds.clear() } - /** - * Check whether we should abort the failedStage due to multiple consecutive fetch failures. - * - * This method updates the running set of failed stage attempts and returns - * true if the number of failures exceeds the allowable number of failures. - */ - private[scheduler] def failedOnFetchAndShouldAbort(stageAttemptId: Int): Boolean = { - fetchFailedAttemptIds.add(stageAttemptId) - fetchFailedAttemptIds.size >= Stage.MAX_CONSECUTIVE_FETCH_FAILURES - } - /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ def makeNewStageAttempt( numPartitionsToCompute: Int, @@ -131,8 +117,3 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). 
*/ def findMissingPartitions(): Seq[Int] } - -private[scheduler] object Stage { - // The number of consecutive failures allowed before a stage is aborted - val MAX_CONSECUTIVE_FETCH_FAILURES = 4 -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9385e3c31e1e4..7767ef1803a06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,18 +17,14 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties -import scala.collection.mutable -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util._ /** @@ -45,8 +41,9 @@ import org.apache.spark.util._ * @param stageId id of the stage this task belongs to * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index of the number in the RDD - * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * @param serializedTaskMetrics a `TaskMetrics` that is created and serialized on the driver side + * and sent to executor side. * * The parameters below are optional: * @param jobId id of the job this task belongs to @@ -57,13 +54,17 @@ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - // The default value is only used in tests. - val metrics: TaskMetrics = TaskMetrics.registered, @transient var localProperties: Properties = new Properties, + // The default value is only used in tests. + serializedTaskMetrics: Array[Byte] = + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, val appAttemptId: Option[String] = None) extends Serializable { + @transient lazy val metrics: TaskMetrics = + SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) + /** * Called by [[org.apache.spark.executor.Executor]] to run this task. * @@ -88,12 +89,20 @@ private[spark] abstract class Task[T]( TaskContext.setTaskContext(context) taskThread = Thread.currentThread() - if (_killed) { - kill(interruptThread = false) + if (_reasonIfKilled != null) { + kill(interruptThread = false, _reasonIfKilled) } - new CallerContext("TASK", appId, appAttemptId, jobId, Option(stageId), Option(stageAttemptId), - Option(taskAttemptId), Option(attemptNumber)).setCurrentContext() + new CallerContext( + "TASK", + SparkEnv.get.conf.get(APP_CALLER_CONTEXT), + appId, + appAttemptId, + jobId, + Option(stageId), + Option(stageAttemptId), + Option(taskAttemptId), + Option(attemptNumber)).setCurrentContext() try { runTask(context) @@ -106,24 +115,33 @@ private[spark] abstract class Task[T]( case t: Throwable => e.addSuppressed(t) } + context.markTaskCompleted(Some(e)) throw e } finally { - // Call the task completion callbacks. 
- context.markTaskCompleted() try { - Utils.tryLogNonFatalError { - // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) - // Notify any tasks waiting for execution memory to be freed to wake up and try to - // acquire memory again. This makes impossible the scenario where a task sleeps forever - // because there are no other tasks left to notify it. Since this is safe to do but may - // not be strictly necessary, we should revisit whether we can remove this in the future. - val memoryManager = SparkEnv.get.memoryManager - memoryManager.synchronized { memoryManager.notifyAll() } - } + // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second + // one is no-op. + context.markTaskCompleted(None) } finally { - TaskContext.unset() + try { + Utils.tryLogNonFatalError { + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask( + MemoryMode.OFF_HEAP) + // Notify any tasks waiting for execution memory to be freed to wake up and try to + // acquire memory again. This makes impossible the scenario where a task sleeps forever + // because there are no other tasks left to notify it. Since this is safe to do but may + // not be strictly necessary, we should revisit whether we can remove this in the + // future. + val memoryManager = SparkEnv.get.memoryManager + memoryManager.synchronized { memoryManager.notifyAll() } + } + } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still + // queried directly in the TaskRunner to check for FetchFailedExceptions. + TaskContext.unset() + } } } } @@ -138,26 +156,26 @@ private[spark] abstract class Task[T]( def preferredLocations: Seq[TaskLocation] = Nil - // Map output tracker epoch. Will be set by TaskScheduler. + // Map output tracker epoch. Will be set by TaskSetManager. var epoch: Long = -1 // Task context, to be initialized in run(). - @transient protected var context: TaskContextImpl = _ + @transient var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ - // A flag to indicate whether the task is killed. This is used in case context is not yet - // initialized when kill() is invoked. - @volatile @transient private var _killed = false + // If non-null, this task has been killed and the reason is as specified. This is used in case + // context is not yet initialized when kill() is invoked. + @volatile @transient private var _reasonIfKilled: String = null protected var _executorDeserializeTime: Long = 0 protected var _executorDeserializeCpuTime: Long = 0 /** - * Whether the task has been killed. + * If defined, this task has been killed and this option contains the reason. */ - def killed: Boolean = _killed + def reasonIfKilled: Option[String] = Option(_reasonIfKilled) /** * Returns the amount of time spent deserializing the RDD and function to be run. 
@@ -171,14 +189,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. - // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. - } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } @@ -190,99 +205,14 @@ private[spark] abstract class Task[T]( * be called multiple times. * If interruptThread is true, we will also call Thread.interrupt() on the Task's executor thread. */ - def kill(interruptThread: Boolean) { - _killed = true + def kill(interruptThread: Boolean, reason: String) { + require(reason != null) + _reasonIfKilled = reason if (context != null) { - context.markInterrupted() + context.markInterrupted(reason) } if (interruptThread && taskThread != null) { taskThread.interrupt() } } } - -/** - * Handles transmission of tasks and their dependencies, because this can be slightly tricky. We - * need to send the list of JARs and files added to the SparkContext with each task to ensure that - * worker nodes find out about it, but we can't make it part of the Task because the user's code in - * the task might depend on one of the JARs. Thus we serialize each task as multiple objects, by - * first writing out its dependencies. - */ -private[spark] object Task { - /** - * Serialize a task and the current app dependencies (files and JARs added to the SparkContext) - */ - def serializeWithDependencies( - task: Task[_], - currentFiles: mutable.Map[String, Long], - currentJars: mutable.Map[String, Long], - serializer: SerializerInstance) - : ByteBuffer = { - - val out = new ByteBufferOutputStream(4096) - val dataOut = new DataOutputStream(out) - - // Write currentFiles - dataOut.writeInt(currentFiles.size) - for ((name, timestamp) <- currentFiles) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write currentJars - dataOut.writeInt(currentJars.size) - for ((name, timestamp) <- currentJars) { - dataOut.writeUTF(name) - dataOut.writeLong(timestamp) - } - - // Write the task properties separately so it is available before full task deserialization. - val propBytes = Utils.serialize(task.localProperties) - dataOut.writeInt(propBytes.length) - dataOut.write(propBytes) - - // Write the task itself and finish - dataOut.flush() - val taskBytes = serializer.serialize(task) - Utils.writeByteBuffer(taskBytes, out) - out.close() - out.toByteBuffer - } - - /** - * Deserialize the list of dependencies in a task serialized with serializeWithDependencies, - * and return the task itself as a serialized ByteBuffer. The caller can then update its - * ClassLoaders and deserialize the task. 
- * - * @return (taskFiles, taskJars, taskProps, taskBytes) - */ - def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { - - val in = new ByteBufferInputStream(serializedTask) - val dataIn = new DataInputStream(in) - - // Read task's files - val taskFiles = new HashMap[String, Long]() - val numFiles = dataIn.readInt() - for (i <- 0 until numFiles) { - taskFiles(dataIn.readUTF()) = dataIn.readLong() - } - - // Read task's JARs - val taskJars = new HashMap[String, Long]() - val numJars = dataIn.readInt() - for (i <- 0 until numJars) { - taskJars(dataIn.readUTF()) = dataIn.readLong() - } - - val propLength = dataIn.readInt() - val propBytes = new Array[Byte](propLength) - dataIn.readFully(propBytes, 0, propLength) - val taskProps = Utils.deserialize[Properties](propBytes) - - // Create a sub-buffer for the rest of the data, which is the serialized Task object - val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, taskProps, subBuffer) - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 1c7c81c488c3a..c98b87148e404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -17,13 +17,32 @@ package org.apache.spark.scheduler +import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Properties -import org.apache.spark.util.SerializableBuffer +import scala.collection.JavaConverters._ +import scala.collection.mutable.{HashMap, Map} + +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * Description of a task that gets passed onto executors to be executed, usually created by - * [[TaskSetManager.resourceOffer]]. + * `TaskSetManager.resourceOffer`. + * + * TaskDescriptions and the associated Task need to be serialized carefully for two reasons: + * + * (1) When a TaskDescription is received by an Executor, the Executor needs to first get the + * list of JARs and files and add these to the classpath, and set the properties, before + * deserializing the Task object (serializedTask). This is why the Properties are included + * in the TaskDescription, even though they're also in the serialized task. + * (2) Because a TaskDescription is serialized and sent to an executor for each task, efficient + * serialization (both in terms of serialization time and serialized buffer size) is + * important. For this reason, we serialize TaskDescriptions ourselves with the + * TaskDescription.encode and TaskDescription.decode methods. This results in a smaller + * serialized size because it avoids serializing unnecessary fields in the Map objects + * (which can introduce significant overhead when the maps are small). 
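As a rough, illustrative aside on the encoding rationale above (not part of this patch): the layout is a sequence of length-prefixed primitives written with `DataOutputStream`, mirroring the `serializeStringLongMap` helper and the length-prefixed UTF-8 property values defined just below. Object and method names here are invented for the sketch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets
import scala.collection.mutable

// Illustrative sketch only: round-trips a String -> Long map and a long property
// value using the same length-prefixed layout TaskDescription.encode/decode uses.
object EncodeSketch {
  def writeStringLongMap(map: mutable.Map[String, Long], out: DataOutputStream): Unit = {
    out.writeInt(map.size)
    for ((key, value) <- map) { out.writeUTF(key); out.writeLong(value) }
  }

  def readStringLongMap(in: DataInputStream): mutable.HashMap[String, Long] = {
    val map = new mutable.HashMap[String, Long]()
    for (_ <- 0 until in.readInt()) { map(in.readUTF()) = in.readLong() }
    map
  }

  def main(args: Array[String]): Unit = {
    val bytesOut = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(bytesOut)
    writeStringLongMap(mutable.HashMap("app.jar" -> 1234L), dataOut)
    // Property values can exceed writeUTF's 64KB limit, so they are written as an
    // explicit byte length followed by raw UTF-8 bytes (the SPARK-19796 workaround).
    val value = "x" * 100000
    val valueBytes = value.getBytes(StandardCharsets.UTF_8)
    dataOut.writeInt(valueBytes.length)
    dataOut.write(valueBytes)
    dataOut.flush()

    val dataIn = new DataInputStream(new ByteArrayInputStream(bytesOut.toByteArray))
    val files = readStringLongMap(dataIn)
    val readBytes = new Array[Byte](dataIn.readInt())
    dataIn.readFully(readBytes)
    assert(files("app.jar") == 1234L)
    assert(new String(readBytes, StandardCharsets.UTF_8) == value)
  }
}
```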
*/ private[spark] class TaskDescription( val taskId: Long, @@ -31,13 +50,95 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet - _serializedTask: ByteBuffer) - extends Serializable { + val addedFiles: Map[String, Long], + val addedJars: Map[String, Long], + val properties: Properties, + val serializedTask: ByteBuffer) { + + override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) +} - // Because ByteBuffers are not serializable, wrap the task in a SerializableBuffer - private val buffer = new SerializableBuffer(_serializedTask) +private[spark] object TaskDescription { + private def serializeStringLongMap(map: Map[String, Long], dataOut: DataOutputStream): Unit = { + dataOut.writeInt(map.size) + for ((key, value) <- map) { + dataOut.writeUTF(key) + dataOut.writeLong(value) + } + } - def serializedTask: ByteBuffer = buffer.value + def encode(taskDescription: TaskDescription): ByteBuffer = { + val bytesOut = new ByteBufferOutputStream(4096) + val dataOut = new DataOutputStream(bytesOut) - override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) + dataOut.writeLong(taskDescription.taskId) + dataOut.writeInt(taskDescription.attemptNumber) + dataOut.writeUTF(taskDescription.executorId) + dataOut.writeUTF(taskDescription.name) + dataOut.writeInt(taskDescription.index) + + // Write files. + serializeStringLongMap(taskDescription.addedFiles, dataOut) + + // Write jars. + serializeStringLongMap(taskDescription.addedJars, dataOut) + + // Write properties. + dataOut.writeInt(taskDescription.properties.size()) + taskDescription.properties.asScala.foreach { case (key, value) => + dataOut.writeUTF(key) + // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values + val bytes = value.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + + // Write the task. The task is already serialized, so write it directly to the byte buffer. + Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut) + + dataOut.close() + bytesOut.close() + bytesOut.toByteBuffer + } + + private def deserializeStringLongMap(dataIn: DataInputStream): HashMap[String, Long] = { + val map = new HashMap[String, Long]() + val mapSize = dataIn.readInt() + for (i <- 0 until mapSize) { + map(dataIn.readUTF()) = dataIn.readLong() + } + map + } + + def decode(byteBuffer: ByteBuffer): TaskDescription = { + val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer)) + val taskId = dataIn.readLong() + val attemptNumber = dataIn.readInt() + val executorId = dataIn.readUTF() + val name = dataIn.readUTF() + val index = dataIn.readInt() + + // Read files. + val taskFiles = deserializeStringLongMap(dataIn) + + // Read jars. + val taskJars = deserializeStringLongMap(dataIn) + + // Read properties. + val properties = new Properties() + val numProperties = dataIn.readInt() + for (i <- 0 until numProperties) { + val key = dataIn.readUTF() + val valueLength = dataIn.readInt() + val valueBytes = new Array[Byte](valueLength) + dataIn.readFully(valueBytes) + properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) + } + + // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). 
+ val serializedTask = byteBuffer.slice() + + new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, + properties, serializedTask) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index eeb7963c9e610..9843eab4f1346 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,8 +17,6 @@ package org.apache.spark.scheduler -import scala.collection.mutable.ListBuffer - import org.apache.spark.TaskState import org.apache.spark.TaskState.TaskState import org.apache.spark.annotation.DeveloperApi @@ -54,7 +52,13 @@ class TaskInfo( * accumulable to be updated multiple times in a single task or for two accumulables with the * same name but different IDs to exist in a task. */ - val accumulables = ListBuffer[AccumulableInfo]() + def accumulables: Seq[AccumulableInfo] = _accumulables + + private[this] var _accumulables: Seq[AccumulableInfo] = Nil + + private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = { + _accumulables = newAccumulables + } /** * The time when the task has completed successfully (including the time to remotely fetch @@ -66,11 +70,13 @@ class TaskInfo( var killed = false - private[spark] def markGettingResult(time: Long = System.currentTimeMillis) { + private[spark] def markGettingResult(time: Long) { gettingResultTime = time } - private[spark] def markFinished(state: TaskState, time: Long = System.currentTimeMillis) { + private[spark] def markFinished(state: TaskState, time: Long) { + // finishTime should be set larger than 0, otherwise "finished" below will return false. + assert(time > 0) finishTime = time if (state == TaskState.FAILED) { failed = true diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index b1addc128e696..a284f7956cd31 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -143,8 +143,12 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Exception => // No-op + } finally { + // If there's an error while deserializing the TaskEndReason, this Runnable + // will die. Still tell the scheduler about the task failure, to avoid a hang + // where the scheduler thinks the task is still running. + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } - scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } }) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cd13eebe74a99..3de7d1f7de22b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,6 +54,13 @@ private[spark] trait TaskScheduler { // Cancel a stage. def cancelTasks(stageId: Int, interruptThread: Boolean): Unit + /** + * Kills a task attempt. + * + * @return Whether the task was successfully killed. + */ + def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. 
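A hedged sketch of how driver-side code might call the new `killTaskAttempt` hook declared above. It is not part of this patch: the object and method names are invented, and it assumes the caller lives in the `org.apache.spark.scheduler` package because `TaskScheduler` is `private[spark]`.

```scala
// Illustrative sketch only; StragglerKiller and monitorKill are invented names.
package org.apache.spark.scheduler

object StragglerKiller {
  def monitorKill(scheduler: TaskScheduler, taskId: Long): Unit = {
    val killed = scheduler.killTaskAttempt(taskId, interruptThread = true,
      reason = "killed by an external monitoring tool")
    if (!killed) {
      // Per the TaskSchedulerImpl implementation later in this patch, `false` means
      // no running task with this ID was found (it may have already finished).
      println(s"Task $taskId is not running; nothing to kill")
    }
  }
}
```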
def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 3e3f1ad031e66..1b6bc9139f9c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{Timer, TimerTask} +import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -38,7 +38,7 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. - * It can also work with a local setup by using a [[LocalSchedulerBackend]] and setting + * It can also work with a local setup by using a `LocalSchedulerBackend` and setting * isLocal to true. It handles common logic, like determining a scheduling order across jobs, waking * up to launch speculative tasks, etc. * @@ -51,13 +51,29 @@ import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils} * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ -private[spark] class TaskSchedulerImpl( +private[spark] class TaskSchedulerImpl private[scheduler]( val sc: SparkContext, val maxTaskFailures: Int, + private[scheduler] val blacklistTrackerOpt: Option[BlacklistTracker], isLocal: Boolean = false) - extends TaskScheduler with Logging -{ - def this(sc: SparkContext) = this(sc, sc.conf.get(config.MAX_TASK_FAILURES)) + extends TaskScheduler with Logging { + + import TaskSchedulerImpl._ + + def this(sc: SparkContext) = { + this( + sc, + sc.conf.get(config.MAX_TASK_FAILURES), + TaskSchedulerImpl.maybeCreateBlacklistTracker(sc)) + } + + def this(sc: SparkContext, maxTaskFailures: Int, isLocal: Boolean) = { + this( + sc, + maxTaskFailures, + TaskSchedulerImpl.maybeCreateBlacklistTracker(sc), + isLocal = isLocal) + } val conf = sc.conf @@ -93,10 +109,12 @@ private[spark] class TaskSchedulerImpl( // Incrementing task IDs val nextTaskId = new AtomicLong(0) - // Number of tasks running on each executor - private val executorIdToTaskCount = new HashMap[String, Int] + // IDs of the tasks running on each executor + private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] - def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap + def runningTasksByExecutors: Map[String, Int] = synchronized { + executorIdToRunningTaskIds.toMap.mapValues(_.size) + } // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host @@ -113,16 +131,18 @@ private[spark] class TaskSchedulerImpl( val mapOutputTracker = SparkEnv.get.mapOutputTracker - var schedulableBuilder: SchedulableBuilder = null - var rootPool: Pool = null + private var schedulableBuilder: SchedulableBuilder = null // default scheduler is FIFO - private val schedulingModeConf = conf.get("spark.scheduler.mode", "FIFO") - val schedulingMode: SchedulingMode = try { - SchedulingMode.withName(schedulingModeConf.toUpperCase) - } catch { - case e: java.util.NoSuchElementException => - throw new SparkException(s"Unrecognized spark.scheduler.mode: $schedulingModeConf") - } + private val schedulingModeConf = conf.get(SCHEDULER_MODE_PROPERTY, 
SchedulingMode.FIFO.toString) + val schedulingMode: SchedulingMode = + try { + SchedulingMode.withName(schedulingModeConf.toUpperCase(Locale.ROOT)) + } catch { + case e: java.util.NoSuchElementException => + throw new SparkException(s"Unrecognized $SCHEDULER_MODE_PROPERTY: $schedulingModeConf") + } + + val rootPool: Pool = new Pool("", schedulingMode, 0, 0) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -133,8 +153,6 @@ private[spark] class TaskSchedulerImpl( def initialize(backend: SchedulerBackend) { this.backend = backend - // temporarily set rootPool name to empty - rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { schedulingMode match { case SchedulingMode.FIFO => @@ -142,7 +160,8 @@ private[spark] class TaskSchedulerImpl( case SchedulingMode.FAIR => new FairSchedulableBuilder(rootPool, conf) case _ => - throw new IllegalArgumentException(s"Unsupported spark.scheduler.mode: $schedulingMode") + throw new IllegalArgumentException(s"Unsupported $SCHEDULER_MODE_PROPERTY: " + + s"$schedulingMode") } } schedulableBuilder.buildPools() @@ -155,7 +174,7 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - speculationScheduler.scheduleAtFixedRate(new Runnable { + speculationScheduler.scheduleWithFixedDelay(new Runnable { override def run(): Unit = Utils.tryOrStopSparkContext(sc) { checkSpeculatableTasks() } @@ -207,7 +226,7 @@ private[spark] class TaskSchedulerImpl( private[scheduler] def createTaskSetManager( taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = { - new TaskSetManager(this, taskSet, maxTaskFailures) + new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt) } override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { @@ -222,7 +241,7 @@ private[spark] class TaskSchedulerImpl( // simply abort the stage. 
tsm.runningTasksSet.foreach { tid => val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId, interruptThread) + backend.killTask(tid, execId, interruptThread, reason = "stage cancelled") } tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) @@ -230,6 +249,18 @@ private[spark] class TaskSchedulerImpl( } } + override def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + logInfo(s"Killing task $taskId: $reason") + val execId = taskIdToExecutorId.get(taskId) + if (execId.isDefined) { + backend.killTask(taskId, execId.get, interruptThread, reason) + true + } else { + logWarning(s"Could not kill task $taskId because no task with that ID was found.") + false + } + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -254,6 +285,8 @@ private[spark] class TaskSchedulerImpl( availableCpus: Array[Int], tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { var launchedTask = false + // nodes and executors that are blacklisted for the entire application have already been + // filtered out by this point for (i <- 0 until shuffledOffers.size) { val execId = shuffledOffers(i).executorId val host = shuffledOffers(i).host @@ -264,7 +297,7 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetManager(tid) = taskSet taskIdToExecutorId(tid) = execId - executorIdToTaskCount(execId) += 1 + executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) launchedTask = true @@ -294,11 +327,11 @@ private[spark] class TaskSchedulerImpl( if (!hostToExecutors.contains(o.host)) { hostToExecutors(o.host) = new HashSet[String]() } - if (!executorIdToTaskCount.contains(o.executorId)) { + if (!executorIdToRunningTaskIds.contains(o.executorId)) { hostToExecutors(o.host) += o.executorId executorAdded(o.executorId, o.host) executorIdToHost(o.executorId) = o.host - executorIdToTaskCount(o.executorId) = 0 + executorIdToRunningTaskIds(o.executorId) = HashSet[Long]() newExecAvail = true } for (rack <- getRackForHost(o.host)) { @@ -306,8 +339,19 @@ private[spark] class TaskSchedulerImpl( } } - // Randomly shuffle offers to avoid always placing tasks on the same set of workers. - val shuffledOffers = Random.shuffle(offers) + // Before making any offers, remove any nodes from the blacklist whose blacklist has expired. Do + // this here to avoid a separate thread and added synchronization overhead, and also because + // updating the blacklist is only relevant when task offers are being made. + blacklistTrackerOpt.foreach(_.applyBlacklistTimeout()) + + val filteredOffers = blacklistTrackerOpt.map { blacklistTracker => + offers.filter { offer => + !blacklistTracker.isNodeBlacklisted(offer.host) && + !blacklistTracker.isExecutorBlacklisted(offer.executorId) + } + }.getOrElse(offers) + + val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) val availableCpus = shuffledOffers.map(o => o.cores).toArray @@ -344,43 +388,47 @@ private[spark] class TaskSchedulerImpl( return tasks } + /** + * Shuffle offers around to avoid always placing tasks on the same workers. Exposed to allow + * overriding in tests, so it can be deterministic. 
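The comment above notes that `shuffleOffers` is exposed so tests can override it for determinism. A minimal sketch of what such a test-only override might look like (the class name is invented, and it assumes the file sits in the `org.apache.spark.scheduler` package since `TaskSchedulerImpl` and `WorkerOffer` are `private[spark]`):

```scala
// Sketch of a test-only subclass that disables the random shuffle of offers so a
// test can assert on deterministic task placement. Not part of this patch.
package org.apache.spark.scheduler

import org.apache.spark.SparkContext

class DeterministicTaskScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) {
  // Keep offers in the order they were given instead of shuffling them.
  override protected def shuffleOffers(
      offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = {
    offers
  }
}
```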
+ */ + protected def shuffleOffers(offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = { + Random.shuffle(offers) + } + def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { var failedExecutor: Option[String] = None var reason: Option[ExecutorLossReason] = None synchronized { try { - if (state == TaskState.LOST && taskIdToExecutorId.contains(tid)) { - // We lost this entire executor, so remember that it's gone - val execId = taskIdToExecutorId(tid) - - if (executorIdToTaskCount.contains(execId)) { - reason = Some( - SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) - removeExecutor(execId, reason.get) - failedExecutor = Some(execId) - } - } taskIdToTaskSetManager.get(tid) match { case Some(taskSet) => - if (TaskState.isFinished(state)) { - taskIdToTaskSetManager.remove(tid) - taskIdToExecutorId.remove(tid).foreach { execId => - if (executorIdToTaskCount.contains(execId)) { - executorIdToTaskCount(execId) -= 1 - } + if (state == TaskState.LOST) { + // TaskState.LOST is only used by the deprecated Mesos fine-grained scheduling mode, + // where each executor corresponds to a single task, so mark the executor as failed. + val execId = taskIdToExecutorId.getOrElse(tid, throw new IllegalStateException( + "taskIdToTaskSetManager.contains(tid) <=> taskIdToExecutorId.contains(tid)")) + if (executorIdToRunningTaskIds.contains(execId)) { + reason = Some( + SlaveLost(s"Task $tid was lost, so marking the executor as lost as well.")) + removeExecutor(execId, reason.get) + failedExecutor = Some(execId) } } - if (state == TaskState.FINISHED) { - taskSet.removeRunningTask(tid) - taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) - } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + if (TaskState.isFinished(state)) { + cleanupTaskState(tid) taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + if (state == TaskState.FINISHED) { + taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) + } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { + taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + } } case None => logError( ("Ignoring update with state %s for TID %s because its task set is gone (this is " + - "likely the result of receiving duplicate task finished status updates)") + "likely the result of receiving duplicate task finished status updates) or its " + + "executor has been marked as failed.") .format(state, tid)) } } catch { @@ -433,7 +481,7 @@ private[spark] class TaskSchedulerImpl( taskState: TaskState, reason: TaskFailedReason): Unit = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && !taskSetManager.someAttemptSucceeded(tid)) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. 
backend.reviveOffers() @@ -491,7 +539,7 @@ private[spark] class TaskSchedulerImpl( var failedExecutor: Option[String] = None synchronized { - if (executorIdToTaskCount.contains(executorId)) { + if (executorIdToRunningTaskIds.contains(executorId)) { val hostPort = executorIdToHost(executorId) logExecutorLoss(executorId, hostPort, reason) removeExecutor(executorId, reason) @@ -533,13 +581,31 @@ private[spark] class TaskSchedulerImpl( logError(s"Lost executor $executorId on $hostPort: $reason") } + /** + * Cleans up the TaskScheduler's state for tracking the given task. + */ + private def cleanupTaskState(tid: Long): Unit = { + taskIdToTaskSetManager.remove(tid) + taskIdToExecutorId.remove(tid).foreach { executorId => + executorIdToRunningTaskIds.get(executorId).foreach { _.remove(tid) } + } + } + /** * Remove an executor from all our data structures and mark it as lost. If the executor's loss * reason is not yet known, do not yet remove its association with its host nor update the status * of any running tasks, since the loss reason defines whether we'll fail those tasks. */ private def removeExecutor(executorId: String, reason: ExecutorLossReason) { - executorIdToTaskCount -= executorId + // The tasks on the lost executor may not send any more status updates (because the executor + // has been lost), so they should be cleaned up here. + executorIdToRunningTaskIds.remove(executorId).foreach { taskIds => + logDebug("Cleaning up TaskScheduler state for tasks " + + s"${taskIds.mkString("[", ",", "]")} on failed executor $executorId") + // We do not notify the TaskSetManager of the task failures because that will + // happen below in the rootPool.executorLost() call. + taskIds.foreach(cleanupTaskState) + } val host = executorIdToHost(executorId) val execs = hostToExecutors.getOrElse(host, new HashSet) @@ -558,6 +624,7 @@ private[spark] class TaskSchedulerImpl( executorIdToHost -= executorId rootPool.executorLost(executorId, host, reason) } + blacklistTrackerOpt.foreach(_.handleRemovedExecutor(executorId)) } def executorAdded(execId: String, host: String) { @@ -577,11 +644,19 @@ private[spark] class TaskSchedulerImpl( } def isExecutorAlive(execId: String): Boolean = synchronized { - executorIdToTaskCount.contains(execId) + executorIdToRunningTaskIds.contains(execId) } def isExecutorBusy(execId: String): Boolean = synchronized { - executorIdToTaskCount.getOrElse(execId, -1) > 0 + executorIdToRunningTaskIds.get(execId).exists(_.nonEmpty) + } + + /** + * Get a snapshot of the currently blacklisted nodes for the entire application. This is + * thread-safe -- it can be called without a lock on the TaskScheduler. + */ + def nodeBlacklist(): scala.collection.immutable.Set[String] = { + blacklistTrackerOpt.map(_.nodeBlacklist()).getOrElse(scala.collection.immutable.Set()) } // By default, rack is unknown @@ -622,16 +697,19 @@ private[spark] class TaskSchedulerImpl( private[spark] object TaskSchedulerImpl { + + val SCHEDULER_MODE_PROPERTY = "spark.scheduler.mode" + /** * Used to balance containers across hosts. * * Accepts a map of hosts to resource offers for that host, and returns a prioritized list of - * resource offers representing the order in which the offers should be used. The resource + * resource offers representing the order in which the offers should be used. The resource * offers are ordered such that we'll allocate one container on each host before allocating a * second container on any host, and so on, in order to reduce the damage if a host fails. 
* - * For example, given , , , returns - * [o1, o5, o4, 02, o6, o3] + * For example, given {@literal }, {@literal } and + * {@literal }, returns {@literal [o1, o5, o4, o2, o6, o3]}. */ def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = { val _keyList = new ArrayBuffer[K](map.size) @@ -662,4 +740,17 @@ private[spark] object TaskSchedulerImpl { retval.toList } + + private def maybeCreateBlacklistTracker(sc: SparkContext): Option[BlacklistTracker] = { + if (BlacklistTracker.isBlacklistEnabled(sc.conf)) { + val executorAllocClient: Option[ExecutorAllocationClient] = sc.schedulerBackend match { + case b: ExecutorAllocationClient => Some(b) + case _ => None + } + Some(new BlacklistTracker(sc, executorAllocClient)) + } else { + None + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index f4b0f55b7686a..e815b7e0cf6c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -28,6 +28,10 @@ import org.apache.spark.util.Clock * (task, executor) / (task, nodes) pairs, and also completely blacklisting executors and nodes * for the entire taskset. * + * It also must store sufficient information in task failures for application level blacklisting, + * which is handled by [[BlacklistTracker]]. Note that BlacklistTracker does not know anything + * about task failures until a taskset completes successfully. + * * THREADING: This class is a helper to [[TaskSetManager]]; as with the methods in * [[TaskSetManager]] this class is designed only to be called from code with a lock on the * TaskScheduler (e.g. its event handlers). It should not be called from other threads. @@ -41,7 +45,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val MAX_FAILED_EXEC_PER_NODE_STAGE = conf.get(config.MAX_FAILED_EXEC_PER_NODE_STAGE) /** - * A map from each executor to the task failures on that executor. + * A map from each executor to the task failures on that executor. This is used for blacklisting + * within this taskset, and it is also relayed onto [[BlacklistTracker]] for app-level + * blacklisting if this taskset completes successfully. */ val execToFailures = new HashMap[String, ExecutorFailuresInTaskSet]() @@ -57,9 +63,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, /** * Return true if this executor is blacklisted for the given task. This does *not* - * need to return true if the executor is blacklisted for the entire stage. - * That is to keep this method as fast as possible in the inner-loop of the - * scheduler, where those filters will have already been applied. + * need to return true if the executor is blacklisted for the entire stage, or blacklisted + * for the entire application. That is to keep this method as fast as possible in the inner-loop + * of the scheduler, where those filters will have already been applied. */ def isExecutorBlacklistedForTask(executorId: String, index: Int): Boolean = { execToFailures.get(executorId).exists { execFailures => @@ -72,10 +78,10 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, } /** - * Return true if this executor is blacklisted for the given stage. Completely ignores - * anything to do with the node the executor is on. 
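Returning to the `prioritizeContainers` ordering documented a few lines above: the sketch below reproduces the documented output `[o1, o5, o4, o2, o6, o3]` under the assumption that the (elided) inputs were roughly `h1 -> [o1, o2, o3]`, `h2 -> [o4]`, `h3 -> [o5, o6]`. It is a standalone illustration, independent of the Spark classes, and visits hosts in descending order of how many offers they have, which appears to match the documented example.

```scala
import scala.collection.mutable.{ArrayBuffer, HashMap}

// Standalone sketch of the round-robin interleaving prioritizeContainers describes:
// take one offer per host per round, so containers spread across hosts first.
object PrioritizeSketch {
  def interleave[K, T](map: HashMap[K, ArrayBuffer[T]]): List[T] = {
    val result = new ArrayBuffer[T]
    // Visit hosts with more offers first, as the documented example suggests.
    val keys = map.keys.toIndexedSeq.sortBy(k => -map(k).size)
    var round = 0
    var found = true
    while (found) {
      found = false
      for (key <- keys) {
        val offers = map(key)
        if (round < offers.size) {
          result += offers(round)
          found = true
        }
      }
      round += 1
    }
    result.toList
  }

  def main(args: Array[String]): Unit = {
    val offers = HashMap(
      "h1" -> ArrayBuffer("o1", "o2", "o3"),
      "h2" -> ArrayBuffer("o4"),
      "h3" -> ArrayBuffer("o5", "o6"))
    println(interleave(offers)) // List(o1, o5, o4, o2, o6, o3)
  }
}
```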
That - * is to keep this method as fast as possible in the inner-loop of the scheduler, where those - * filters will already have been applied. + * Return true if this executor is blacklisted for the given stage. Completely ignores whether + * the executor is blacklisted for the entire application (or anything to do with the node the + * executor is on). That is to keep this method as fast as possible in the inner-loop of the + * scheduler, where those filters will already have been applied. */ def isExecutorBlacklistedForTaskSet(executorId: String): Boolean = { blacklistedExecs.contains(executorId) @@ -90,7 +96,7 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, exec: String, index: Int): Unit = { val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) - execFailures.updateWithFailure(index) + execFailures.updateWithFailure(index, clock.getTimeMillis()) // check if this task has also failed on other executors on the same host -- if its gone // over the limit, blacklist this task from the entire host. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index b766e4148e496..a41b059fa7dec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,11 +19,10 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer -import java.util.Arrays import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.math.{max, min} +import scala.math.max import scala.util.control.NonFatal import org.apache.spark._ @@ -31,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} +import org.apache.spark.util.collection.MedianHeap /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of @@ -51,6 +51,7 @@ private[spark] class TaskSetManager( sched: TaskSchedulerImpl, val taskSet: TaskSet, val maxTaskFailures: Int, + blacklistTracker: Option[BlacklistTracker] = None, clock: Clock = new SystemClock()) extends Schedulable with Logging { private val conf = sched.sc.conf @@ -62,6 +63,8 @@ private[spark] class TaskSetManager( // Limit of bytes for total size of results (default is 1GB) val maxResultSize = Utils.getMaxResultSize(conf) + val speculationEnabled = conf.getBoolean("spark.speculation", false) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -69,40 +72,46 @@ private[spark] class TaskSetManager( val tasks = taskSet.tasks val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) + + // For each task, tracks whether a copy of the task has succeeded. A task will also be + // marked as "succeeded" if it failed with a fetch failure, in which case it should not + // be re-run because the missing map data needs to be regenerated first. 
val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) - var tasksSuccessful = 0 + private[scheduler] var tasksSuccessful = 0 - var weight = 1 - var minShare = 0 + val weight = 1 + val minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId val name = "TaskSet_" + taskSet.id var parent: Pool = null - var totalResultSize = 0L - var calculatedTasks = 0 + private var totalResultSize = 0L + private var calculatedTasks = 0 - private val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { - if (BlacklistTracker.isBlacklistEnabled(conf)) { - Some(new TaskSetBlacklist(conf, stageId, clock)) - } else { - None + private[scheduler] val taskSetBlacklistHelperOpt: Option[TaskSetBlacklist] = { + blacklistTracker.map { _ => + new TaskSetBlacklist(conf, stageId, clock) } } - val runningTasksSet = new HashSet[Long] + private[scheduler] val runningTasksSet = new HashSet[Long] override def runningTasks: Int = runningTasksSet.size + def someAttemptSucceeded(tid: Long): Boolean = { + successful(taskInfos(tid).index) + } + // True once no more tasks should be launched for this task set manager. TaskSetManagers enter // the zombie state once at least one attempt of each task has completed successfully, or if the // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie // state in order to continue to track and account for the running tasks. // TODO: We should kill any running task attempts when the task set manager becomes a zombie. - var isZombie = false + private[scheduler] var isZombie = false // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -126,17 +135,22 @@ private[spark] class TaskSetManager( private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. - var pendingTasksWithNoPrefs = new ArrayBuffer[Int] + private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - val allPendingTasks = new ArrayBuffer[Int] + private val allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. - val speculatableTasks = new HashSet[Int] + private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - val taskInfos = new HashMap[Long, TaskInfo] + private val taskInfos = new HashMap[Long, TaskInfo] + + // Use a MedianHeap to record durations of successful tasks so we know when to launch + // speculative tasks. This is only used when speculation is enabled, to avoid the overhead + // of inserting into the heap when the heap won't be used. + val successfulTaskDurations = new MedianHeap() // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = @@ -145,7 +159,7 @@ private[spark] class TaskSetManager( // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception // was printed. This should ideally be an LRU map that can drop old exceptions automatically. 
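As an aside on the `successfulTaskDurations` MedianHeap introduced above: it exists so the speculation check no longer has to copy and sort every finished task's duration on each pass; the scheduler only needs a running median to decide whether a still-running attempt is an outlier. Below is a minimal, self-contained sketch of that decision, with hypothetical names and a plain sorted buffer standing in for Spark's `MedianHeap` (illustrative only, not the TaskSetManager code):

```scala
import scala.collection.mutable.ArrayBuffer

// Illustrative sketch only -- not Spark's MedianHeap or TaskSetManager logic.
// speculationQuantile ~ spark.speculation.quantile, speculationMultiplier ~ spark.speculation.multiplier.
class SpeculationSketch(
    numTasks: Int,
    speculationQuantile: Double = 0.75,
    speculationMultiplier: Double = 1.5,
    minTimeToSpeculation: Long = 100L) {

  private val successfulDurations = ArrayBuffer.empty[Long]

  // Record the duration of a successfully finished attempt
  // (cf. successfulTaskDurations.insert in handleSuccessfulTask).
  def taskSucceeded(durationMs: Long): Unit = successfulDurations += durationMs

  // Median of the recorded durations; a MedianHeap maintains this incrementally
  // instead of re-sorting on every call.
  private def median: Long = {
    val sorted = successfulDurations.sorted
    sorted(sorted.length / 2)
  }

  // Should a task that has already been running for `runningMs` get a speculative copy?
  def shouldSpeculate(runningMs: Long): Boolean = {
    val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
    successfulDurations.nonEmpty &&
      successfulDurations.length >= minFinishedForSpeculation && {
        val threshold = math.max(speculationMultiplier * median, minTimeToSpeculation.toDouble)
        runningMs > threshold
      }
  }
}
```

The real implementation inserts into the heap from `handleSuccessfulTask` and, as shown further down in this patch, iterates only over `runningTasksSet` when checking for speculatable tasks.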
 - val recentExceptions = HashMap[String, (Int, Long)]() + private val recentExceptions = HashMap[String, (Int, Long)]() // Figure out the current map output tracker epoch and set it on all tasks val epoch = sched.mapOutputTracker.getEpoch @@ -160,21 +174,28 @@ private[spark] class TaskSetManager( addPendingTask(i) } - // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling - var myLocalityLevels = computeValidLocalityLevels() - var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level + /** + * Track the set of locality levels which are valid given the tasks' locality preferences and + * the set of currently available executors. This is updated as executors are added and removed. + * This allows a performance optimization of skipping levels that aren't relevant (e.g., skip + * PROCESS_LOCAL if no tasks could be run PROCESS_LOCAL for the current set of executors). + */ + private[scheduler] var myLocalityLevels = computeValidLocalityLevels() + + // Time to wait at each level + private[scheduler] var localityWaits = myLocalityLevels.map(getLocalityWait) // Delay scheduling variables: we keep track of our current locality level and the time we // last launched a task at that level, and move up a level when localityWaits[curLevel] expires. // We then move down if we manage to launch a "more local" task. - var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels - var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level + private var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels + private var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE - var emittedTaskSizeWarning = false + private[scheduler] var emittedTaskSizeWarning = false /** Add a task to all the pending-task lists that it should be on. */ private def addPendingTask(index: Int) { @@ -447,9 +468,8 @@ private[spark] class TaskSetManager( lastLaunchTime = curTime } // Serialize and return the task - val startTime = clock.getTimeMillis() val serializedTask: ByteBuffer = try { - Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser) + ser.serialize(task) } catch { // If the task cannot be serialized, then there's no point to re-attempt the task, // as it will always fail. So just abort the whole task-set.
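The delay-scheduling fields above (`myLocalityLevels`, `localityWaits`, `currentLocalityIndex`, `lastLaunchTime`) drive a small state machine: stay at the most local level until its wait expires, then fall through to the next, less local level. A rough sketch of that advancement with simplified names (the real logic lives in `TaskSetManager.getAllowedLocalityLevel`):

```scala
// Illustrative sketch only; simplified from TaskSetManager's delay scheduling.
object DelaySchedulingSketch {

  /**
   * Given the index of the current locality level, the time we last launched a task at it,
   * the per-level waits, and the current time, return the level we are now allowed to
   * schedule at, plus the adjusted launch time.
   */
  def advanceLocalityIndex(
      currentIndex: Int,
      lastLaunchTime: Long,
      localityWaits: Array[Long],
      curTime: Long): (Int, Long) = {
    var index = currentIndex
    var launchTime = lastLaunchTime
    // While the wait at the current level has expired, give up on it and move up
    // to the next (less local) level, crediting the time already spent waiting.
    while (index < localityWaits.length - 1 && curTime - launchTime >= localityWaits(index)) {
      launchTime += localityWaits(index)
      index += 1
    }
    (index, launchTime)
  }
}
```

Launching a task at a more local level later moves the index back down, which is the "move down" case described in the comment above.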
@@ -476,8 +496,16 @@ private[spark] class TaskSetManager( s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") sched.dagScheduler.taskStarted(task, info) - new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, - taskName, index, serializedTask) + new TaskDescription( + taskId, + attemptNum, + execId, + taskName, + index, + sched.sc.addedFiles, + sched.sc.addedJars, + task.localProperties, + serializedTask) } } else { None @@ -487,6 +515,12 @@ private[spark] class TaskSetManager( private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) + if (tasksSuccessful == numTasks) { + blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet( + taskSet.stageId, + taskSet.stageAttemptId, + taskSetBlacklistHelperOpt.get.execToFailures)) + } } } @@ -589,6 +623,7 @@ private[spark] class TaskSetManager( private[scheduler] def abortIfCompletelyBlacklisted( hostToExecutors: HashMap[String, HashSet[String]]): Unit = { taskSetBlacklistHelperOpt.foreach { taskSetBlacklist => + val appBlacklist = blacklistTracker.get // Only look for unschedulable tasks when at least one executor has registered. Otherwise, // task sets will be (unnecessarily) aborted in cases when no executors have registered yet. if (hostToExecutors.nonEmpty) { @@ -615,13 +650,15 @@ private[spark] class TaskSetManager( val blacklistedEverywhere = hostToExecutors.forall { case (host, execsOnHost) => // Check if the task can run on the node val nodeBlacklisted = - taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || - taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) + appBlacklist.isNodeBlacklisted(host) || + taskSetBlacklist.isNodeBlacklistedForTaskSet(host) || + taskSetBlacklist.isNodeBlacklistedForTask(host, indexInTaskSet) if (nodeBlacklisted) { true } else { // Check if the task can run on any of the executors execsOnHost.forall { exec => + appBlacklist.isExecutorBlacklisted(exec) || taskSetBlacklist.isExecutorBlacklistedForTaskSet(exec) || taskSetBlacklist.isExecutorBlacklistedForTask(exec, indexInTaskSet) } @@ -643,7 +680,7 @@ private[spark] class TaskSetManager( */ def handleTaskGettingResult(tid: Long): Unit = { val info = taskInfos(tid) - info.markGettingResult() + info.markGettingResult(clock.getTimeMillis()) sched.dagScheduler.taskGettingResult(info) } @@ -671,22 +708,23 @@ private[spark] class TaskSetManager( def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) val index = info.index - info.markFinished(TaskState.FINISHED) + info.markFinished(TaskState.FINISHED, clock.getTimeMillis()) + if (speculationEnabled) { + successfulTaskDurations.insert(info.duration) + } removeRunningTask(tid) - // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the - // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not - // "deserialize" the value when holding a lock to avoid blocking other threads. So we call - // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. - // Note: "result.value()" only deserializes the value when it's called at the first time, so - // here "result.value()" just returns the value and won't block other threads. - sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) + // Kill any other attempts for the same task (since those are unnecessary now that one // attempt completed successfully). 
for (attemptInfo <- taskAttempts(index) if attemptInfo.running) { logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") - sched.backend.killTask(attemptInfo.taskId, attemptInfo.executorId, true) + sched.backend.killTask( + attemptInfo.taskId, + attemptInfo.executorId, + interruptThread = true, + reason = "another attempt succeeded") } if (!successful(index)) { tasksSuccessful += 1 @@ -702,6 +740,13 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the + // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not + // "deserialize" the value when holding a lock to avoid blocking other threads. So we call + // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here. + // Note: "result.value()" only deserializes the value when it's called at the first time, so + // here "result.value()" just returns the value and won't block other threads. + sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates, info) maybeFinishTaskSet() } @@ -715,7 +760,7 @@ private[spark] class TaskSetManager( return } removeRunningTask(tid) - info.markFinished(state) + info.markFinished(state, clock.getTimeMillis()) val index = info.index copiesRunning(index) -= 1 var accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty @@ -782,10 +827,10 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (successful(index)) { - logInfo( - s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, " + - "but another instance of the task has already succeeded, " + - "so not re-queuing the task to be re-executed.") + logInfo(s"Task ${info.id} in stage ${taskSet.id} (TID $tid) failed, but the task will not" + + s" be re-executed (either because the task failed with a shuffle data fetch failure," + + s" so the previous stage needs to be re-run, or because a different copy of the task" + + s" has already succeeded).") } else { addPendingTask(index) } @@ -850,7 +895,8 @@ private[spark] class TaskSetManager( // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled + && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -882,8 +928,6 @@ private[spark] class TaskSetManager( * Check for tasks to be speculated and return true if there are any. This is called periodically * by the TaskScheduler. * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. 
*/ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a @@ -894,16 +938,16 @@ private[spark] class TaskSetManager( var foundTasks = false val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) + if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { val time = clock.getTimeMillis() - val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray - Arrays.sort(durations) - val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.length - 1)) + var medianDuration = successfulTaskDurations.median val threshold = max(SPECULATION_MULTIPLIER * medianDuration, minTimeToSpeculation) // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { + for (tid <- runningTasksSet) { + val info = taskInfos(tid) val index = info.index if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { @@ -942,18 +986,18 @@ private[spark] class TaskSetManager( private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = { import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY} val levels = new ArrayBuffer[TaskLocality.TaskLocality] - if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 && + if (!pendingTasksForExecutor.isEmpty && pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) { levels += PROCESS_LOCAL } - if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 && + if (!pendingTasksForHost.isEmpty && pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) { levels += NODE_LOCAL } if (!pendingTasksWithNoPrefs.isEmpty) { levels += NO_PREF } - if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 && + if (!pendingTasksForRack.isEmpty && pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) { levels += RACK_LOCAL } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index edc8aac5d1515..6b49bd699a13a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -28,14 +28,22 @@ private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable private[spark] object CoarseGrainedClusterMessages { - case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case object RetrieveSparkAppConfig extends CoarseGrainedClusterMessage + + case class SparkAppConfig( + sparkProperties: Seq[(String, String)], + ioEncryptionKey: Option[Array[Byte]]) + extends CoarseGrainedClusterMessage case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage - case class KillTask(taskId: Long, executor: String, interruptThread: Boolean) + case class KillTask(taskId: Long, executor: String, interruptThread: Boolean, reason: String) + extends CoarseGrainedClusterMessage + + case class 
KillExecutorsOnHost(host: String) extends CoarseGrainedClusterMessage sealed trait RegisterExecutorResponse @@ -94,7 +102,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RequestExecutors( requestedTotal: Int, localityAwareTasks: Int, - hostToLocalTaskCount: Map[String, Int]) + hostToLocalTaskCount: Map[String, Int], + nodeBlacklist: Set[String]) extends CoarseGrainedClusterMessage // Check if an executor was force-killed but for a reason unrelated to the running tasks. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 10d55c87fb8de..dc82bb7704727 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -98,11 +102,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. protected val executorsPendingLossReason = new HashSet[String] - // If this DriverEndpoint is changed to support multiple threads, - // then this may need to be changed so that we don't share the serializer - // instance across threads - private val ser = SparkEnv.get.closureSerializer.newInstance() - protected val addressToExecutorId = new HashMap[RpcAddress, String] private val reviveThread = @@ -137,14 +136,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case ReviveOffers => makeOffers() - case KillTask(taskId, executorId, interruptThread) => + case KillTask(taskId, executorId, interruptThread, reason) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorEndpoint.send(KillTask(taskId, executorId, interruptThread)) + executorInfo.executorEndpoint.send( + KillTask(taskId, executorId, interruptThread, reason)) case None => // Ignoring the task kill since the executor is not registered. 
logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } + + case KillExecutorsOnHost(host) => + scheduler.getExecutorsAliveOnHost(host).foreach { exec => + killExecutors(exec.toSeq, replace = true, force = true) + } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -153,6 +158,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) + } else if (scheduler.nodeBlacklist != null && + scheduler.nodeBlacklist.contains(hostname)) { + // If the cluster manager gives us an executor on a blacklisted node (because it + // already started allocating those resources before we informed it of our blacklist, + // or if it ignored our blacklist), then we reject that executor immediately. + logInfo(s"Rejecting $executorId as it has been blacklisted.") + executorRef.send(RegisterExecutorFailed(s"Executor is blacklisted: $executorId")) + context.reply(true) } else { // If the executor's rpc env is not listening for incoming connections, `hostPort` // will be null, and the client connection should be used to contact the executor. @@ -206,18 +219,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp removeExecutor(executorId, reason) context.reply(true) - case RetrieveSparkProps => - context.reply(sparkProperties) + case RetrieveSparkAppConfig => + val reply = SparkAppConfig(sparkProperties, + SparkEnv.get.securityManager.getIOEncryptionKey()) + context.reply(reply) } // Make fake resource offers on all executors private def makeOffers() { - // Filter out executors under killing - val activeExecutors = executorDataMap.filterKeys(executorIsAlive) - val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toIndexedSeq - launchTasks(scheduler.resourceOffers(workOffers)) + // Make sure no executor is killed while some task is launching on it + val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(executorIsAlive) + val workOffers = activeExecutors.map { case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + }.toIndexedSeq + scheduler.resourceOffers(workOffers) + } + if (!taskDescs.isEmpty) { + launchTasks(taskDescs) + } } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -230,12 +251,20 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - // Filter out executors under killing - if (executorIsAlive(executorId)) { - val executorData = executorDataMap(executorId) - val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) - launchTasks(scheduler.resourceOffers(workOffers)) + // Make sure no executor is killed while some task is launching on it + val taskDescs = CoarseGrainedSchedulerBackend.this.synchronized { + // Filter out executors under killing + if (executorIsAlive(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = IndexedSeq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + scheduler.resourceOffers(workOffers) + } else { + Seq.empty + } + } + if 
(!taskDescs.isEmpty) { + launchTasks(taskDescs) } } @@ -247,7 +276,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { for (task <- tasks.flatten) { - val serializedTask = ser.serialize(task) + val serializedTask = TaskDescription.encode(task) if (serializedTask.limit >= maxRpcMessageSize) { scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => try { @@ -362,7 +391,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp try { if (driverEndpoint != null) { logInfo("Shutting down all executors") - driverEndpoint.askWithRetry[Boolean](StopExecutors) + driverEndpoint.askSync[Boolean](StopExecutors) } } catch { case e: Exception => @@ -374,7 +403,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp stopExecutors() try { if (driverEndpoint != null) { - driverEndpoint.askWithRetry[Boolean](StopDriver) + driverEndpoint.askSync[Boolean](StopDriver) } } catch { case e: Exception => @@ -388,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -404,8 +434,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(ReviveOffers) } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverEndpoint.send(KillTask(taskId, executorId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + driverEndpoint.send(KillTask(taskId, executorId, interruptThread, reason)) } override def defaultParallelism(): Int = { @@ -461,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -498,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -525,15 +566,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp protected def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = Future.successful(false) - /** - * Request that the cluster manager kill the specified 
executors. - * @return whether the kill request is acknowledged. If list to kill is empty, it will return - * false. - */ - final override def killExecutors(executorIds: Seq[String]): Seq[String] = { - killExecutors(executorIds, replace = false, force = false) - } - /** * Request that the cluster manager kill the specified executors. * @@ -542,12 +574,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * replacement is being requested, then the tasks will not count towards the limit. * * @param executorIds identifiers of executors to kill - * @param replace whether to replace the killed executors with new ones - * @param force whether to force kill busy executors - * @return whether the kill request is acknowledged. If list to kill is empty, it will return - * false. + * @param replace whether to replace the killed executors with new ones, default false + * @param force whether to force kill busy executors, default false + * @return the ids of the executors acknowledged by the cluster manager to be removed. */ - final def killExecutors( + final override def killExecutors( executorIds: Seq[String], replace: Boolean, force: Boolean): Seq[String] = { @@ -573,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) @@ -603,6 +643,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ protected def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful(false) + + /** + * Request that the cluster manager kill all executors on a given host. + * @return whether the kill request is acknowledged. + */ + final override def killExecutorsOnHost(host: String): Boolean = { + logInfo(s"Requesting to kill any and all executors on host ${host}") + // A potential race exists if a new executor attempts to register on a host + // that is on the blacklist and is no longer valid. To avoid this race, + // all executor registration and killing happens in the event loop. This way, either + // an executor will fail to register, or will be killed when all executors on a host + // are killed. + // Kill all the executors on this host in an event loop to ensure serialization.
+ driverEndpoint.send(KillExecutorsOnHost(host)) + true + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index 04d40e2907cff..0529fe9eed4da 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future @@ -42,7 +43,7 @@ private[spark] class StandaloneSchedulerBackend( with Logging { private var client: StandaloneAppClient = null - private var stopping = false + private val stopping = new AtomicBoolean(false) private val launcherBackend = new LauncherBackend() { override protected def onStopRequest(): Unit = stop(SparkAppHandle.State.KILLED) } @@ -93,7 +94,7 @@ private[spark] class StandaloneSchedulerBackend( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts) - val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("") + val webUrl = sc.ui.map(_.webUrl).getOrElse("") val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) // If we're using dynamic allocation, set our initial executor limit to 0 for now. // ExecutorAllocationManager will send the real initial limit to the Master later. @@ -103,8 +104,8 @@ private[spark] class StandaloneSchedulerBackend( } else { None } - val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) + val appDesc = ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, + webUrl, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor, initialExecutorLimit) client = new StandaloneAppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() launcherBackend.setState(SparkAppHandle.State.SUBMITTED) @@ -112,7 +113,7 @@ private[spark] class StandaloneSchedulerBackend( launcherBackend.setState(SparkAppHandle.State.RUNNING) } - override def stop(): Unit = synchronized { + override def stop(): Unit = { stop(SparkAppHandle.State.FINISHED) } @@ -125,21 +126,21 @@ private[spark] class StandaloneSchedulerBackend( override def disconnected() { notifyContext() - if (!stopping) { + if (!stopping.get) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { notifyContext() - if (!stopping) { + if (!stopping.get) { launcherBackend.setState(SparkAppHandle.State.KILLED) logError("Application has been killed. Reason: " + reason) try { scheduler.error(reason) } finally { // Ensure the application terminates, as we can no longer run jobs. 
- sc.stop() + sc.stopInNewThread() } } } @@ -206,20 +207,20 @@ private[spark] class StandaloneSchedulerBackend( registrationBarrier.release() } - private def stop(finalState: SparkAppHandle.State): Unit = synchronized { - try { - stopping = true - - super.stop() - client.stop() + private def stop(finalState: SparkAppHandle.State): Unit = { + if (stopping.compareAndSet(false, true)) { + try { + super.stop() + client.stop() - val callback = shutdownCallback - if (callback != null) { - callback(this) + val callback = shutdownCallback + if (callback != null) { + callback(this) + } + } finally { + launcherBackend.setState(finalState) + launcherBackend.close() } - } finally { - launcherBackend.setState(finalState) - launcherBackend.close() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 7a73e8ed8a38f..35509bc2f85b9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -34,7 +34,7 @@ private case class ReviveOffers() private case class StatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) -private case class KillTask(taskId: Long, interruptThread: Boolean) +private case class KillTask(taskId: Long, interruptThread: Boolean, reason: String) private case class StopExecutor() @@ -70,8 +70,8 @@ private[spark] class LocalEndpoint( reviveOffers() } - case KillTask(taskId, interruptThread) => - executor.killTask(taskId, interruptThread) + case KillTask(taskId, interruptThread, reason) => + executor.killTask(taskId, interruptThread, reason) } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -84,8 +84,7 @@ private[spark] class LocalEndpoint( val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber, - task.name, task.serializedTask) + executor.launchTask(executorBackend, task) } } } @@ -144,8 +143,9 @@ private[spark] class LocalSchedulerBackend( override def defaultParallelism(): Int = scheduler.conf.getInt("spark.default.parallelism", totalCores) - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localEndpoint.send(KillTask(taskId, interruptThread)) + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String) { + localEndpoint.send(KillTask(taskId, interruptThread, reason)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/package.scala b/core/src/main/scala/org/apache/spark/scheduler/package.scala index f0dbfc2ac5f48..4847c41710b2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/package.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/package.scala @@ -18,7 +18,7 @@ package org.apache.spark /** - * Spark's scheduling components. This includes the [[org.apache.spark.scheduler.DAGScheduler]] and - * lower level [[org.apache.spark.scheduler.TaskScheduler]]. + * Spark's scheduling components. This includes the `org.apache.spark.scheduler.DAGScheduler` and + * lower level `org.apache.spark.scheduler.TaskScheduler`. 
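Back in `StandaloneSchedulerBackend`, the `stopping` flag changed from a `synchronized` boolean to an `AtomicBoolean`, and `stop(finalState)` now guards its body with `compareAndSet`, so the shutdown work runs at most once even if several threads race to stop the backend. The same stop-once idiom in isolation (a sketch, not the Spark class):

```scala
import java.util.concurrent.atomic.AtomicBoolean

// Sketch of the stop-once idiom: whichever caller wins the compare-and-set performs the
// shutdown; later callers return immediately without blocking on a lock.
class StopOnceSketch {
  private val stopping = new AtomicBoolean(false)

  def stop(): Unit = {
    if (stopping.compareAndSet(false, true)) {
      try {
        // release clients, callbacks, and other resources exactly once
      } finally {
        // always record the final state, even if cleanup throws
      }
    }
  }

  def isStopping: Boolean = stopping.get
}
```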
*/ package object scheduler diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala index 8f15f50bee814..78dabb42ac9d2 100644 --- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala +++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala @@ -16,79 +16,108 @@ */ package org.apache.spark.security -import java.io.{InputStream, OutputStream} +import java.io.{EOFException, InputStream, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{ReadableByteChannel, WritableByteChannel} import java.util.Properties +import javax.crypto.KeyGenerator import javax.crypto.spec.{IvParameterSpec, SecretKeySpec} +import scala.collection.JavaConverters._ + +import com.google.common.io.ByteStreams import org.apache.commons.crypto.random._ import org.apache.commons.crypto.stream._ -import org.apache.hadoop.io.Text import org.apache.spark.SparkConf -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.network.util.{CryptoUtils, JavaUtils} /** * A util class for manipulating IO encryption and decryption streams. */ private[spark] object CryptoStreamUtils extends Logging { - /** - * Constants and variables for spark IO encryption - */ - val SPARK_IO_TOKEN = new Text("SPARK_IO_TOKEN") // The initialization vector length in bytes. val IV_LENGTH_IN_BYTES = 16 // The prefix of IO encryption related configurations in Spark configuration. val SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX = "spark.io.encryption.commons.config." - // The prefix for the configurations passing to Apache Commons Crypto library. - val COMMONS_CRYPTO_CONF_PREFIX = "commons.crypto." /** - * Helper method to wrap [[OutputStream]] with [[CryptoOutputStream]] for encryption. + * Helper method to wrap `OutputStream` with `CryptoOutputStream` for encryption. */ def createCryptoOutputStream( os: OutputStream, - sparkConf: SparkConf): OutputStream = { - val properties = toCryptoConf(sparkConf) - val iv = createInitializationVector(properties) + sparkConf: SparkConf, + key: Array[Byte]): OutputStream = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) os.write(iv) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoOutputStream(transformationStr, properties, os, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + new CryptoOutputStream(params.transformation, params.conf, os, params.keySpec, + new IvParameterSpec(iv)) } /** - * Helper method to wrap [[InputStream]] with [[CryptoInputStream]] for decryption. + * Wrap a `WritableByteChannel` for encryption. + */ + def createWritableChannel( + channel: WritableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): WritableByteChannel = { + val params = new CryptoParams(key, sparkConf) + val iv = createInitializationVector(params.conf) + val helper = new CryptoHelperChannel(channel) + + helper.write(ByteBuffer.wrap(iv)) + new CryptoOutputStream(params.transformation, params.conf, helper, params.keySpec, + new IvParameterSpec(iv)) + } + + /** + * Helper method to wrap `InputStream` with `CryptoInputStream` for decryption. 
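Together with `createCryptoOutputStream` above and the `createCryptoInputStream` and `createKey` helpers that follow, the encryption key now travels as an explicit byte array rather than through Hadoop credentials, which makes a simple encrypt/decrypt round trip possible. A sketch of such a round trip (illustrative only; `CryptoStreamUtils` is `private[spark]`, so this assumes code living under `org.apache.spark`):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.google.common.io.ByteStreams

import org.apache.spark.SparkConf
import org.apache.spark.security.CryptoStreamUtils

// Illustrative round trip through the stream helpers; not part of the patch itself.
object CryptoRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    // A fresh AES key, sized and generated according to the IO encryption settings.
    val key: Array[Byte] = CryptoStreamUtils.createKey(conf)

    // Encrypt: the IV is generated here and written as the first bytes of the sink.
    val sink = new ByteArrayOutputStream()
    val out = CryptoStreamUtils.createCryptoOutputStream(sink, conf, key)
    out.write("some block data".getBytes("UTF-8"))
    out.close()

    // Decrypt: the matching helper reads the IV back off the head of the stream.
    val in = CryptoStreamUtils.createCryptoInputStream(
      new ByteArrayInputStream(sink.toByteArray), conf, key)
    assert(new String(ByteStreams.toByteArray(in), "UTF-8") == "some block data")
  }
}
```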
*/ def createCryptoInputStream( is: InputStream, - sparkConf: SparkConf): InputStream = { - val properties = toCryptoConf(sparkConf) + sparkConf: SparkConf, + key: Array[Byte]): InputStream = { val iv = new Array[Byte](IV_LENGTH_IN_BYTES) - is.read(iv, 0, iv.length) - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - val key = credentials.getSecretKey(SPARK_IO_TOKEN) - val transformationStr = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) - new CryptoInputStream(transformationStr, properties, is, - new SecretKeySpec(key, "AES"), new IvParameterSpec(iv)) + ByteStreams.readFully(is, iv) + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, is, params.keySpec, + new IvParameterSpec(iv)) } /** - * Get Commons-crypto configurations from Spark configurations identified by prefix. + * Wrap a `ReadableByteChannel` for decryption. */ + def createReadableChannel( + channel: ReadableByteChannel, + sparkConf: SparkConf, + key: Array[Byte]): ReadableByteChannel = { + val iv = new Array[Byte](IV_LENGTH_IN_BYTES) + val buf = ByteBuffer.wrap(iv) + JavaUtils.readFully(channel, buf) + + val params = new CryptoParams(key, sparkConf) + new CryptoInputStream(params.transformation, params.conf, channel, params.keySpec, + new IvParameterSpec(iv)) + } + def toCryptoConf(conf: SparkConf): Properties = { - val props = new Properties() - conf.getAll.foreach { case (k, v) => - if (k.startsWith(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX)) { - props.put(COMMONS_CRYPTO_CONF_PREFIX + k.substring( - SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.length()), v) - } - } - props + CryptoUtils.toCryptoConf(SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX, + conf.getAll.toMap.asJava.entrySet()) + } + + /** + * Creates a new encryption key. + */ + def createKey(conf: SparkConf): Array[Byte] = { + val keyLen = conf.get(IO_ENCRYPTION_KEY_SIZE_BITS) + val ioKeyGenAlgorithm = conf.get(IO_ENCRYPTION_KEYGEN_ALGORITHM) + val keyGen = KeyGenerator.getInstance(ioKeyGenAlgorithm) + keyGen.init(keyLen) + keyGen.generateKey().getEncoded() } /** @@ -106,4 +135,34 @@ private[spark] object CryptoStreamUtils extends Logging { } iv } + + /** + * This class is a workaround for CRYPTO-125, that forces all bytes to be written to the + * underlying channel. Since the callers of this API are using blocking I/O, there are no + * concerns with regards to CPU usage here. + */ + private class CryptoHelperChannel(sink: WritableByteChannel) extends WritableByteChannel { + + override def write(src: ByteBuffer): Int = { + val count = src.remaining() + while (src.hasRemaining()) { + sink.write(src) + } + count + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + + } + + private class CryptoParams(key: Array[Byte], sparkConf: SparkConf) { + + val keySpec = new SecretKeySpec(key, "AES") + val transformation = sparkConf.get(IO_CRYPTO_CIPHER_TRANSFORMATION) + val conf = toCryptoConf(sparkConf) + + } + } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 8b72da2ee01b7..f60dcfddfdc20 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -131,7 +131,7 @@ private[spark] class JavaSerializerInstance( * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. 
* - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 0d26281fe1076..e15166d11c243 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.serializer import java.io._ import java.nio.ByteBuffer +import java.util.Locale import javax.annotation.Nullable import scala.collection.JavaConverters._ @@ -43,9 +44,10 @@ import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, S import org.apache.spark.util.collection.CompactBuffer /** - * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * A Spark serializer that uses the + * Kryo serialization library. * - * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * @note This serializer is not guaranteed to be wire-compatible across different versions of * Spark. It is intended to be used to serialize/de-serialize data within a single * Spark application. */ @@ -243,7 +245,8 @@ class KryoDeserializationStream( kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. - case e: KryoException if e.getMessage.toLowerCase.contains("buffer underflow") => + case e: KryoException + if e.getMessage.toLowerCase(Locale.ROOT).contains("buffer underflow") => throw new EOFException } } @@ -312,7 +315,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer, useUnsafe: Boole } catch { case e: KryoException if e.getMessage.startsWith("Buffer overflow") => throw new SparkException(s"Kryo serialization failed: ${e.getMessage}. To avoid this, " + - "increase spark.kryoserializer.buffer.max value.") + "increase spark.kryoserializer.buffer.max value.", e) } finally { releaseKryo(kryo) } @@ -383,9 +386,16 @@ private[serializer] object KryoSerializer { classOf[HighlyCompressedMapStatus], classOf[CompactBuffer[_]], classOf[BlockManagerId], + classOf[Array[Boolean]], classOf[Array[Byte]], classOf[Array[Short]], + classOf[Array[Int]], classOf[Array[Long]], + classOf[Array[Float]], + classOf[Array[Double]], + classOf[Array[Char]], + classOf[Array[String]], + classOf[Array[Array[String]]], classOf[BoundedPriorityQueue[_]], classOf[SparkConf] ) diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index cb95246d5b0ca..cb8b1cc077637 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -23,7 +23,6 @@ import javax.annotation.concurrent.NotThreadSafe import scala.reflect.ClassTag -import org.apache.spark.SparkEnv import org.apache.spark.annotation.{DeveloperApi, Private} import org.apache.spark.util.NextIterator @@ -40,7 +39,7 @@ import org.apache.spark.util.NextIterator * * 2. Java serialization interface. * - * Note that serializers are not required to be wire-compatible across different versions of Spark. 
+ * @note Serializers are not required to be wire-compatible across different versions of Spark. * They are intended to be used to serialize/de-serialize data within a single Spark application. */ @DeveloperApi @@ -78,7 +77,7 @@ abstract class Serializer { * position = 0 * serOut.write(obj1) * serOut.flush() - * position = # of bytes writen to stream so far + * position = # of bytes written to stream so far * obj1Bytes = output[0:position-1] * serOut.write(obj2) * serOut.flush() @@ -126,7 +125,7 @@ abstract class SerializerInstance { * A stream for writing serialized objects. */ @DeveloperApi -abstract class SerializationStream { +abstract class SerializationStream extends Closeable { /** The most general-purpose method to write an object. */ def writeObject[T: ClassTag](t: T): SerializationStream /** Writes the object representing the key of a key-value pair. */ @@ -134,7 +133,7 @@ abstract class SerializationStream { /** Writes the object representing the value of a key-value pair. */ def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value) def flush(): Unit - def close(): Unit + override def close(): Unit def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { while (iter.hasNext) { @@ -150,14 +149,14 @@ abstract class SerializationStream { * A stream for reading serialized objects. */ @DeveloperApi -abstract class DeserializationStream { +abstract class DeserializationStream extends Closeable { /** The most general-purpose method to read an object. */ def readObject[T: ClassTag](): T /** Reads the object representing the key of a key-value pair. */ def readKey[T: ClassTag](): T = readObject[T]() /** Reads the object representing the value of a key-value pair. */ def readValue[T: ClassTag](): T = readObject[T]() - def close(): Unit + override def close(): Unit /** * Read the elements of this stream through an iterator. This can only be called once, as diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2156d576f1874..bb7ed8709ba8a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -23,7 +23,6 @@ import java.nio.ByteBuffer import scala.reflect.ClassTag import org.apache.spark.SparkConf -import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.storage._ @@ -33,7 +32,12 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * Component which configures serialization, compression and encryption for various Spark * components, including automatic selection of which [[Serializer]] to use for shuffles. 
 */ -private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) { +private[spark] class SerializerManager( + defaultSerializer: Serializer, + conf: SparkConf, + encryptionKey: Option[Array[Byte]]) { + + def this(defaultSerializer: Serializer, conf: SparkConf) = this(defaultSerializer, conf, None) private[this] val kryoSerializer = new KryoSerializer(conf) @@ -63,9 +67,6 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar // Whether to compress shuffle output temporarily spilled to disk private[this] val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - // Whether to enable IO encryption - private[this] val enableIOEncryption = conf.get(IO_ENCRYPTION_ENABLED) - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -73,12 +74,17 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar * loaded yet. */ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + def encryptionEnabled: Boolean = encryptionKey.isDefined + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } - def getSerializer(ct: ClassTag[_]): Serializer = { - if (canUseKryo(ct)) { + // SPARK-18617: the feature from SPARK-13990 cannot be applied to Spark Streaming yet; in the + // worst case a streaming job based on `Receiver` mode cannot run properly on Spark 2.x. As a + // first step, it is reasonable to disable the `kryo auto pick` feature for streaming.
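Concretely, callers derive `autoPick` from the block type: stream blocks written by Receiver-based streaming keep the default serializer, while other blocks may be upgraded to Kryo when the class tag allows it. A sketch of that call pattern (names are from this patch, but the snippet itself is illustrative and assumes Spark-internal access, since `SerializerManager` is `private[spark]`):

```scala
import scala.reflect.ClassTag

import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.storage.{BlockId, StreamBlockId}

// Illustrative only: how getSerializer(ct, autoPick) is meant to be driven by the block type.
object SerializerPickSketch {
  def chooseSerializer[T: ClassTag](
      serializerManager: SerializerManager,
      blockId: BlockId): Serializer = {
    // Never auto-pick Kryo for receiver stream blocks (see the SPARK-18617 note above).
    val autoPick = !blockId.isInstanceOf[StreamBlockId]
    serializerManager.getSerializer(implicitly[ClassTag[T]], autoPick)
  }
}
```

This is the same pattern `dataSerializeStream` and the other block-serialization paths adopt later in this hunk.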
+ def getSerializer(ct: ClassTag[_], autoPick: Boolean): Serializer = { + if (autoPick && canUseKryo(ct)) { kryoSerializer } else { defaultSerializer @@ -124,28 +130,32 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: InputStream): InputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoInputStream(s, conf) else s + def wrapForEncryption(s: InputStream): InputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } + .getOrElse(s) } /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: OutputStream): OutputStream = { - if (enableIOEncryption) CryptoStreamUtils.createCryptoOutputStream(s, conf) else s + def wrapForEncryption(s: OutputStream): OutputStream = { + encryptionKey + .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } + .getOrElse(s) } /** * Wrap an output stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { + def wrapForCompression(blockId: BlockId, s: OutputStream): OutputStream = { if (shouldCompress(blockId)) compressionCodec.compressedOutputStream(s) else s } /** * Wrap an input stream for compression if block compression is enabled for its block type */ - private[this] def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { + def wrapForCompression(blockId: BlockId, s: InputStream): InputStream = { if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s } @@ -155,12 +165,15 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar outputStream: OutputStream, values: Iterator[T]): Unit = { val byteStream = new BufferedOutputStream(outputStream) - val ser = getSerializer(implicitly[ClassTag[T]]).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(implicitly[ClassTag[T]], autoPick).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } /** Serializes into a chunked byte buffer. 
 */ - def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { + def dataSerialize[T: ClassTag]( + blockId: BlockId, + values: Iterator[T]): ChunkedByteBuffer = { dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]]) } @@ -171,8 +184,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar classTag: ClassTag[_]): ChunkedByteBuffer = { val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) val byteStream = new BufferedOutputStream(bbos) - val ser = getSerializer(classTag).newInstance() - ser.serializeStream(wrapStream(blockId, byteStream)).writeAll(values).close() + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() bbos.toChunkedByteBuffer } @@ -185,9 +199,10 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar inputStream: InputStream) (classTag: ClassTag[T]): Iterator[T] = { val stream = new BufferedInputStream(inputStream) - getSerializer(classTag) + val autoPick = !blockId.isInstanceOf[StreamBlockId] + getSerializer(classTag, autoPick) .newInstance() - .deserializeStream(wrapStream(blockId, stream)) + .deserializeStream(wrapForCompression(blockId, stream)) .asIterator.asInstanceOf[Iterator[T]] } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index b9d83495d29b6..ba3e0e395e958 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockFetcherItr = new ShuffleBlockFetcherIterator( + val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.shuffleClient, blockManager, mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, - SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)) - - // Wrap the streams for compression and encryption based on configuration - val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => - serializerManager.wrapStream(blockId, inputStream) - } + SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() // Create a key/value iterator for each stream - val recordIter = wrappedStreams.flatMap { wrappedStream => + val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) => // Note: the asKeyValueIterator below wraps a key/value iterator inside of a // NextIterator. The NextIterator makes sure that close() is called on the // underlying InputStream when all records have been read. @@ -98,8 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Sort the output if there is a sort ordering defined. dep.keyOrdering match { case Some(keyOrd: Ordering[K]) => - // Create an ExternalSorter to sort the data.
Note that if spark.shuffle.spill is disabled, - // the ExternalSorter won't spill to disk. + // Create an ExternalSorter to sort the data. val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) sorter.insertAll(aggregatedIter) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 498c12e196ce0..265a8acfa8d61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskFailedReason} +import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. * * Note that bmAddress can be null. + * + * To prevent user code from hiding this fetch failure, in the constructor we call + * [[TaskContext.setFetchFailed()]]. This means that you *must* throw this exception immediately + * after creating it -- you cannot create it, check some condition, and then decide to ignore it + * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( bmAddress: BlockManagerId, @@ -45,6 +50,12 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } + // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code + // which intercepts this exception (possibly wrapping it), the Executor can still tell there was + // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option + // because the TaskContext is not defined in some test cases. + Option(TaskContext.get()).map(_.setFetchFailed(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 91858f0912b65..15540485170d0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -61,7 +61,7 @@ private[spark] class IndexShuffleBlockResolver( /** * Remove data file and index file that contain the output data from one map. - * */ + */ def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { @@ -132,7 +132,7 @@ private[spark] class IndexShuffleBlockResolver( * replace them with new ones. * * Note: the `lengths` will be updated to match the existing index file if use the existing ones. 
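The new class comment above is the key contract: constructing a `FetchFailedException` already records the failure in the `TaskContext`, so the exception must be thrown as soon as it is created. A standalone sketch of why register-in-constructor makes "create now, maybe throw later" unsafe; `Ctx` and `Failure` are hypothetical stand-ins, not Spark classes:

```scala
object RegisterOnConstructSketch {
  // Hypothetical stand-ins for TaskContext and FetchFailedException.
  class Ctx { var fetchFailed: Option[Failure] = None }
  object Ctx { val current = new Ctx }

  class Failure(msg: String) extends Exception(msg) {
    // Mirrors the contract described in the patch: constructing the exception already
    // records the failure in the task context, before anything is thrown.
    Ctx.current.fetchFailed = Some(this)
  }

  def main(args: Array[String]): Unit = {
    // Unsafe with this pattern: creating the failure "just in case" already marks the task
    // as fetch-failed, even though nothing is ever thrown.
    val speculative = new Failure("created but never thrown")
    println(Ctx.current.fetchFailed.contains(speculative)) // true

    // Safe: construct and throw in a single step, so the side effect and the throw go together.
    try throw new Failure("real fetch failure")
    catch { case f: Failure => println(s"caught: ${f.getMessage}") }
  }
}
```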
- * */ + */ def writeIndexFileAndCommit( shuffleId: Int, mapId: Int, diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 5e977a16febe1..bfb4dc698e325 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -82,13 +82,13 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** - * Register a shuffle with the manager and obtain a handle for it to pass to tasks. + * Obtains a [[ShuffleHandle]] to pass to tasks. */ override def registerShuffle[K, V, C]( shuffleId: Int, numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) { + if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't // need map-side aggregation, then write numPartitions files directly and just concatenate // them at the end. This avoids doing serialization and deserialization twice to merge diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala index 5c03609e5e5e5..1279b281ad8d8 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllRDDResource.scala @@ -70,7 +70,13 @@ private[spark] object AllRDDResource { address = status.blockManagerId.hostPort, memoryUsed = status.memUsedByRdd(rddId), memoryRemaining = status.memRemaining, - diskUsed = status.diskUsedByRdd(rddId) + diskUsed = status.diskUsedByRdd(rddId), + onHeapMemoryUsed = Some( + if (!rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + offHeapMemoryUsed = Some( + if (rddInfo.storageLevel.useOffHeap) status.memUsedByRdd(rddId) else 0L), + onHeapMemoryRemaining = status.onHeapMemRemaining, + offHeapMemoryRemaining = status.offHeapMemRemaining ) } ) } else { None diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index acb7c23079681..1818935392eb3 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -142,8 +142,10 @@ private[v1] object AllStagesResource { index = uiData.taskInfo.index, attempt = uiData.taskInfo.attemptNumber, launchTime = new Date(uiData.taskInfo.launchTime), + duration = uiData.taskDuration, executorId = uiData.taskInfo.executorId, host = uiData.taskInfo.host, + status = uiData.taskInfo.status, taskLocality = uiData.taskInfo.taskLocality.toString(), speculative = uiData.taskInfo.speculative, accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 17bc04303fa8b..f17b637754826 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -18,6 +18,7 @@ package org.apache.spark.status.api.v1 import 
java.util.zip.ZipOutputStream import javax.servlet.ServletContext +import javax.servlet.http.HttpServletRequest import javax.ws.rs._ import javax.ws.rs.core.{Context, Response} @@ -40,7 +41,7 @@ import org.apache.spark.ui.SparkUI * HistoryServerSuite. */ @Path("/v1") -private[v1] class ApiRootResource extends UIRootFromServletContext { +private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications") def getApplicationList(): ApplicationListResource = { @@ -56,21 +57,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJobs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs") def getJobs(@PathParam("appId") appId: String): AllJobsResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllJobsResource(ui) } } @Path("applications/{appId}/jobs/{jobId: \\d+}") def getJob(@PathParam("appId") appId: String): OneJobResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneJobResource(ui) } } @@ -79,21 +80,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getJob( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneJobResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneJobResource(ui) } } @Path("applications/{appId}/executors") def getExecutors(@PathParam("appId") appId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new ExecutorListResource(ui) } } @Path("applications/{appId}/allexecutors") def getAllExecutors(@PathParam("appId") appId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllExecutorListResource(ui) } } @@ -102,7 +103,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): ExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new ExecutorListResource(ui) } } @@ -111,15 +112,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getAllExecutors( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllExecutorListResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllExecutorListResource(ui) } } - @Path("applications/{appId}/stages") def getStages(@PathParam("appId") appId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllStagesResource(ui) } } @@ -128,14 +128,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getStages( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllStagesResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllStagesResource(ui) } } @Path("applications/{appId}/stages/{stageId: \\d+}") def getStage(@PathParam("appId") appId: String): OneStageResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneStageResource(ui) } } @@ -144,14 +144,14 @@ private[v1] class 
ApiRootResource extends UIRootFromServletContext { def getStage( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneStageResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneStageResource(ui) } } @Path("applications/{appId}/storage/rdd") def getRdds(@PathParam("appId") appId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new AllRDDResource(ui) } } @@ -160,14 +160,14 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdds( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): AllRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new AllRDDResource(ui) } } @Path("applications/{appId}/storage/rdd/{rddId: \\d+}") def getRdd(@PathParam("appId") appId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, None) { ui => + withSparkUI(appId, None) { ui => new OneRDDResource(ui) } } @@ -176,7 +176,7 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { def getRdd( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): OneRDDResource = { - uiRoot.withSparkUI(appId, Some(attemptId)) { ui => + withSparkUI(appId, Some(attemptId)) { ui => new OneRDDResource(ui) } } @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". + withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -199,6 +212,21 @@ private[v1] class ApiRootResource extends UIRootFromServletContext { new VersionResource(uiRoot) } + @Path("applications/{appId}/environment") + def getEnvironment(@PathParam("appId") appId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, None) { ui => + new ApplicationEnvironmentResource(ui) + } + } + + @Path("applications/{appId}/{attemptId}/environment") + def getEnvironment( + @PathParam("appId") appId: String, + @PathParam("attemptId") attemptId: String): ApplicationEnvironmentResource = { + withSparkUI(appId, Some(attemptId)) { ui => + new ApplicationEnvironmentResource(ui) + } + } } private[spark] object ApiRootResource { @@ -234,19 +262,6 @@ private[spark] trait UIRoot { .status(Response.Status.SERVICE_UNAVAILABLE) .build() } - - /** - * Get the spark UI with the given appID, and apply a function - * to it. 
If there is no such app, throw an appropriate exception - */ - def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - getSparkUI(appKey) match { - case Some(ui) => - f(ui) - case None => throw new NotFoundException("no such app: " + appId) - } - } def securityManager: SecurityManager } @@ -263,13 +278,37 @@ private[v1] object UIRootFromServletContext { } } -private[v1] trait UIRootFromServletContext { +private[v1] trait ApiRequestContext { @Context - var servletContext: ServletContext = _ + protected var servletContext: ServletContext = _ + + @Context + protected var httpRequest: HttpServletRequest = _ def uiRoot: UIRoot = UIRootFromServletContext.getUiRoot(servletContext) + + + /** + * Get the spark UI with the given appID, and apply a function + * to it. If there is no such app, throw an appropriate exception + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + uiRoot.getSparkUI(appKey) match { + case Some(ui) => + val user = httpRequest.getRemoteUser() + if (!ui.securityManager.checkUIViewPermissions(user)) { + throw new ForbiddenException(raw"""user "$user" is not authorized""") + } + f(ui) + case None => throw new NotFoundException("no such app: " + appId) + } + } } +private[v1] class ForbiddenException(msg: String) extends WebApplicationException( + Response.status(Response.Status.FORBIDDEN).entity(msg).build()) + private[v1] class NotFoundException(msg: String) extends WebApplicationException( new NoSuchElementException(msg), Response diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala new file mode 100644 index 0000000000000..739a8aceae861 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationEnvironmentResource.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
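`withSparkUI` now lives on `ApiRequestContext`: it resolves the UI for the requested application, checks the remote user's view permission, and only then runs the handler, signalling failures with `ForbiddenException` or `NotFoundException`. A simplified standalone sketch of that lookup-check-apply shape; `Ui`, `apps`, and the exception classes below are illustrative stand-ins rather than the Spark types:

```scala
object WithUiSketch {
  // Hypothetical stand-ins for SparkUI and the application registry.
  final case class Ui(appId: String, allowedUsers: Set[String])
  private val apps = Map("app-1" -> Ui("app-1", Set("alice")))

  class NotFound(msg: String) extends Exception(msg)
  class Forbidden(msg: String) extends Exception(msg)

  // Mirrors the shape of withSparkUI: resolve the UI, enforce view permission, then run f.
  def withUi[T](appId: String, user: String)(f: Ui => T): T =
    apps.get(appId) match {
      case Some(ui) if !ui.allowedUsers.contains(user) =>
        throw new Forbidden(s"""user "$user" is not authorized""")
      case Some(ui) => f(ui)
      case None => throw new NotFound(s"no such app: $appId")
    }

  def main(args: Array[String]): Unit = {
    println(withUi("app-1", "alice")(_.appId))            // app-1
    try withUi("app-1", "bob")(_.appId)
    catch { case e: Forbidden => println(e.getMessage) }  // user "bob" is not authorized
  }
}
```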
+ */ +package org.apache.spark.status.api.v1 + +import javax.ws.rs._ +import javax.ws.rs.core.MediaType + +import org.apache.spark.ui.SparkUI + +@Produces(Array(MediaType.APPLICATION_JSON)) +private[v1] class ApplicationEnvironmentResource(ui: SparkUI) { + + @GET + def getEnvironmentInfo(): ApplicationEnvironmentInfo = { + val listener = ui.environmentListener + listener.synchronized { + val jvmInfo = Map(listener.jvmInformation: _*) + val runtime = new RuntimeInfo( + jvmInfo("Java Version"), + jvmInfo("Java Home"), + jvmInfo("Scala Version")) + + new ApplicationEnvironmentInfo( + runtime, + listener.sparkProperties, + listener.systemProperties, + listener.classpathEntries) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 76779290d45e6..f039744e7f67f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -30,6 +30,8 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { @QueryParam("status") status: JList[ApplicationStatus], @DefaultValue("2010-01-01") @QueryParam("minDate") minDate: SimpleDateParam, @DefaultValue("3000-01-01") @QueryParam("maxDate") maxDate: SimpleDateParam, + @DefaultValue("2010-01-01") @QueryParam("minEndDate") minEndDate: SimpleDateParam, + @DefaultValue("3000-01-01") @QueryParam("maxEndDate") maxEndDate: SimpleDateParam, @QueryParam("limit") limit: Integer) : Iterator[ApplicationInfo] = { @@ -43,11 +45,27 @@ private[v1] class ApplicationListResource(uiRoot: UIRoot) { // keep the app if *any* attempts fall in the right time window ((!anyRunning && includeCompleted) || (anyRunning && includeRunning)) && app.attempts.exists { attempt => - val start = attempt.startTime.getTime - start >= minDate.timestamp && start <= maxDate.timestamp + isAttemptInRange(attempt, minDate, maxDate, minEndDate, maxEndDate, anyRunning) } }.take(numApps) } + + private def isAttemptInRange( + attempt: ApplicationAttemptInfo, + minStartDate: SimpleDateParam, + maxStartDate: SimpleDateParam, + minEndDate: SimpleDateParam, + maxEndDate: SimpleDateParam, + anyRunning: Boolean): Boolean = { + val startTimeOk = attempt.startTime.getTime >= minStartDate.timestamp && + attempt.startTime.getTime <= maxStartDate.timestamp + // If the maxEndDate is in the past, exclude all running apps. 
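The new `minEndDate`/`maxEndDate` parameters are applied by the range check implemented just below: completed attempts must end inside the window, while running attempts (which have no end time yet) pass only if `maxEndDate` is still in the future. A standalone sketch of that predicate, assuming plain epoch-millisecond `Long`s in place of `SimpleDateParam` and a per-attempt running flag:

```scala
object AttemptRangeSketch {
  // Simplified stand-in: timestamps are epoch millis, endTime is None while the attempt runs.
  final case class Attempt(startTime: Long, endTime: Option[Long])

  def inRange(
      a: Attempt,
      minStart: Long, maxStart: Long,
      minEnd: Long, maxEnd: Long,
      now: Long): Boolean = {
    val startOk = a.startTime >= minStart && a.startTime <= maxStart
    val endOk = a.endTime match {
      // Running attempt: only excluded when the requested end window is entirely in the past.
      case None      => maxEnd > now
      case Some(end) => end >= minEnd && end <= maxEnd
    }
    startOk && endOk
  }

  def main(args: Array[String]): Unit = {
    val now = 1000L
    println(inRange(Attempt(100, Some(200)), 0, 500, 150, 300, now)) // true
    println(inRange(Attempt(100, None), 0, 500, 0, 900, now))        // false: maxEnd in the past
    println(inRange(Attempt(100, None), 0, 500, 0, 2000, now))       // true: still running
  }
}
```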
+ val endTimeOkForRunning = anyRunning && (maxEndDate.timestamp > System.currentTimeMillis()) + val endTimeOkForCompleted = !anyRunning && (attempt.endTime.getTime >= minEndDate.timestamp && + attempt.endTime.getTime <= maxEndDate.timestamp) + val endTimeOk = endTimeOkForRunning || endTimeOkForCompleted + startTimeOk && endTimeOk + } } private[spark] object ApplicationsListResource { @@ -72,7 +90,8 @@ private[spark] object ApplicationsListResource { }, lastUpdated = new Date(internalAttemptInfo.lastUpdated), sparkUser = internalAttemptInfo.sparkUser, - completed = internalAttemptInfo.completed + completed = internalAttemptInfo.completed, + appSparkVersion = internalAttemptInfo.appSparkVersion ) } ) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala index 6ca59c2f3caeb..ab53881594180 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ExecutorListResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{GET, PathParam, Produces} +import javax.ws.rs.{GET, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala index f6a9f9c5573db..76af33c1a18db 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala @@ -21,7 +21,7 @@ import java.lang.annotation.Annotation import java.lang.reflect.Type import java.nio.charset.StandardCharsets import java.text.SimpleDateFormat -import java.util.{Calendar, SimpleTimeZone} +import java.util.{Calendar, Locale, SimpleTimeZone} import javax.ws.rs.Produces import javax.ws.rs.core.{MediaType, MultivaluedMap} import javax.ws.rs.ext.{MessageBodyWriter, Provider} @@ -86,7 +86,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{ private[spark] object JacksonMessageWriter { def makeISODateFormat: SimpleDateFormat = { - val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'") + val iso8601 = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS'GMT'", Locale.US) val cal = Calendar.getInstance(new SimpleTimeZone(0, "GMT")) iso8601.setCalendar(cal) iso8601 diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala index b4a991eda35f3..1cd37185d6601 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SecurityFilter.scala @@ -21,14 +21,14 @@ import javax.ws.rs.core.Response import javax.ws.rs.ext.Provider @Provider -private[v1] class SecurityFilter extends ContainerRequestFilter with UIRootFromServletContext { +private[v1] class SecurityFilter extends ContainerRequestFilter with ApiRequestContext { override def filter(req: ContainerRequestContext): Unit = { - val user = Option(req.getSecurityContext.getUserPrincipal).map { _.getName }.orNull + val user = httpRequest.getRemoteUser() if (!uiRoot.securityManager.checkUIViewPermissions(user)) { req.abortWith( Response .status(Response.Status.FORBIDDEN) - .entity(raw"""user "$user"is not authorized""") + .entity(raw"""user "$user" is not authorized""") 
.build() ) } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala index 0c71cd2382225..d8d5e8958b23c 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/SimpleDateParam.scala @@ -17,7 +17,7 @@ package org.apache.spark.status.api.v1 import java.text.{ParseException, SimpleDateFormat} -import java.util.TimeZone +import java.util.{Locale, TimeZone} import javax.ws.rs.WebApplicationException import javax.ws.rs.core.Response import javax.ws.rs.core.Response.Status @@ -25,12 +25,12 @@ import javax.ws.rs.core.Response.Status private[v1] class SimpleDateParam(val originalValue: String) { val timestamp: Long = { - val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz") + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSz", Locale.US) try { format.parse(originalValue).getTime() } catch { case _: ParseException => - val gmtDay = new SimpleDateFormat("yyyy-MM-dd") + val gmtDay = new SimpleDateFormat("yyyy-MM-dd", Locale.US) gmtDay.setTimeZone(TimeZone.getTimeZone("GMT")) try { gmtDay.parse(originalValue).getTime() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 44a929b310384..f6203271f3cd2 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -38,7 +38,8 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) { + val completed: Boolean = false, + val appSparkVersion: String) { def getStartTimeEpoch: Long = startTime.getTime def getEndTimeEpoch: Long = endTime.getTime def getLastUpdatedEpoch: Long = lastUpdated.getTime @@ -73,8 +74,16 @@ class ExecutorSummary private[spark]( val totalInputBytes: Long, val totalShuffleRead: Long, val totalShuffleWrite: Long, + val isBlacklisted: Boolean, val maxMemory: Long, - val executorLogs: Map[String, String]) + val executorLogs: Map[String, String], + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, @@ -110,7 +119,11 @@ class RDDDataDistribution private[spark]( val address: String, val memoryUsed: Long, val memoryRemaining: Long, - val diskUsed: Long) + val diskUsed: Long, + val onHeapMemoryUsed: Option[Long], + val offHeapMemoryUsed: Option[Long], + val onHeapMemoryRemaining: Option[Long], + val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( val blockName: String, @@ -157,8 +170,10 @@ class TaskData private[spark]( val index: Int, val attempt: Int, val launchTime: Date, + val duration: Option[Long] = None, val executorId: String, val host: String, + val status: String, val taskLocality: String, val speculative: Boolean, val accumulatorUpdates: Seq[AccumulableInfo], @@ -249,3 +264,14 @@ class AccumulableInfo private[spark]( class VersionInfo private[spark]( val spark: String) + +class ApplicationEnvironmentInfo private[spark] ( + val runtime: RuntimeInfo, + val sparkProperties: Seq[(String, String)], + val systemProperties: Seq[(String, String)], + val classpathEntries: Seq[(String, String)]) + +class RuntimeInfo 
private[spark]( + val javaVersion: String, + val javaHome: String, + val scalaVersion: String) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index dd8f5bacb9f6e..3db59837fbebd 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag -import com.google.common.collect.ConcurrentHashMultiset +import com.google.common.collect.{ConcurrentHashMultiset, ImmutableMultiset} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -340,7 +340,7 @@ private[storage] class BlockInfoManager extends Logging { val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]() val readLocks = synchronized { - readLocksByTask.remove(taskAttemptId).get + readLocksByTask.remove(taskAttemptId).getOrElse(ImmutableMultiset.of[BlockId]()) } val writeLocks = synchronized { writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty) @@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging { blocksWithReleasedLocks } + /** Returns the number of locks held by the given task. Used only for testing. */ + private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = { + readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) + + writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0) + } + /** * Returns the number of blocks tracked. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b83324e0fc..b3e458448974f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer +import java.nio.channels.Channels import scala.collection.mutable import scala.collection.mutable.HashMap @@ -33,7 +34,7 @@ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -45,13 +46,61 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer - /* Class for returning a fetched block and associated metrics. */ private[spark] class BlockResult( val data: Iterator[Any], val readMethod: DataReadMethod.Value, val bytes: Long) +/** + * Abstracts away how blocks are stored and provides different ways to read the underlying block + * data. Callers should call [[dispose()]] when they're done with the block. + */ +private[spark] trait BlockData { + + def toInputStream(): InputStream + + /** + * Returns a Netty-friendly wrapper for the block's data. + * + * Please see `ManagedBuffer.convertToNetty()` for more details. 
+ */ + def toNetty(): Object + + def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer + + def toByteBuffer(): ByteBuffer + + def size: Long + + def dispose(): Unit + +} + +private[spark] class ByteBufferBlockData( + val buffer: ChunkedByteBuffer, + val shouldDispose: Boolean) extends BlockData { + + override def toInputStream(): InputStream = buffer.toInputStream(dispose = false) + + override def toNetty(): Object = buffer.toNetty + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + buffer.copy(allocator) + } + + override def toByteBuffer(): ByteBuffer = buffer.toByteBuffer + + override def size: Long = buffer.size + + override def dispose(): Unit = { + if (shouldDispose) { + buffer.dispose() + } + } + +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -62,7 +111,7 @@ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - serializerManager: SerializerManager, + val serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -91,15 +140,15 @@ private[spark] class BlockManager( // Actual storage of where blocks are kept private[spark] val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memoryManager, this) - private[spark] val diskStore = new DiskStore(conf, diskBlockManager) + private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager) memoryManager.setMemoryStore(memoryStore) // Note: depending on the memory manager, `maxMemory` may actually vary over time. // However, since we use this only for reporting and logging, what we actually want here is // the absolute maximum value that `maxMemory` can ever possibly reach. We may need // to revisit whether reporting this value as the "max" is intuitive to the user. - private val maxMemory = - memoryManager.maxOnHeapStorageMemory + memoryManager.maxOffHeapStorageMemory + private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory + private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory // Port used by the external shuffle service. In Yarn mode, this may be already be // set through the Hadoop configuration as the server is launched in the Yarn NM. @@ -125,8 +174,7 @@ private[spark] class BlockManager( // standard BlockTransferService to directly connect to other Executors. private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), - securityManager.isSaslEncryptionEnabled()) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) } else { blockTransferService } @@ -178,7 +226,8 @@ private[spark] class BlockManager( val idFromMaster = master.registerBlockManager( id, - maxMemory, + maxOnHeapMemory, + maxOffHeapMemory, slaveEndpoint) blockManagerId = if (idFromMaster != null) idFromMaster else id @@ -256,7 +305,7 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. 
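`BlockData` gives the `BlockManager` one handle for a block's bytes with several read paths and an explicit `dispose()`. A minimal standalone analogue backed by a heap byte array, following the trait shown above but omitting the Netty and `ChunkedByteBuffer` members, which depend on Spark-internal types:

```scala
import java.io.{ByteArrayInputStream, InputStream}
import java.nio.ByteBuffer

object BlockDataSketch {
  // Simplified analogue of BlockData: one backing representation, several ways to read it.
  trait SimpleBlockData {
    def toInputStream(): InputStream
    def toByteBuffer(): ByteBuffer
    def size: Long
    def dispose(): Unit
  }

  final class ArrayBlockData(bytes: Array[Byte]) extends SimpleBlockData {
    override def toInputStream(): InputStream = new ByteArrayInputStream(bytes)
    override def toByteBuffer(): ByteBuffer = ByteBuffer.wrap(bytes)
    override def size: Long = bytes.length.toLong
    // Nothing to free for a heap array; a memory-mapped or direct-buffer implementation would
    // release its resources here, which is why callers are asked to call dispose() when done.
    override def dispose(): Unit = ()
  }

  def main(args: Array[String]): Unit = {
    val data: SimpleBlockData = new ArrayBlockData("hello".getBytes("UTF-8"))
    println(data.size)                 // 5
    println(data.toByteBuffer().get()) // 104 ('h')
    data.dispose()
  }
}
```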
logInfo(s"BlockManager $blockManagerId re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager(blockManagerId, maxOnHeapMemory, maxOffHeapMemory, slaveEndpoint) reportAllBlocks() } @@ -302,7 +351,8 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { - case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer) + case Some(blockData) => + new BlockManagerManagedBuffer(blockInfoManager, blockId, blockData, true) case None => // If this block manager receives a request for a block that it doesn't have then it's // likely that the master has outdated block statuses for this block. Therefore, we send @@ -315,6 +365,9 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. + * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. */ override def putBlockData( blockId: BlockId, @@ -458,21 +511,22 @@ private[spark] class BlockManager( val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) } else if (level.useDisk && diskStore.contains(blockId)) { + val diskData = diskStore.getBytes(blockId) val iterToReturn: Iterator[Any] = { - val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { val diskValues = serializerManager.dataDeserializeStream( blockId, - diskBytes.toInputStream(dispose = true))(info.classTag) + diskData.toInputStream())(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - .map {_.toInputStream(dispose = false)} - .getOrElse { diskBytes.toInputStream(dispose = true) } + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map { _.toInputStream(dispose = false) } + .getOrElse { diskData.toInputStream() } serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } - val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) + val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, + releaseLockAndDispose(blockId, diskData)) Some(new BlockResult(ci, DataReadMethod.Disk, info.size)) } else { handleLocalReadFailure(blockId) @@ -483,7 +537,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: BlockId): Option[ChunkedByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[BlockData] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -491,9 +545,9 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. 
- Option( - new ChunkedByteBuffer( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) + val buf = new ChunkedByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Some(new ByteBufferBlockData(buf, true)) } else { blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) } } @@ -505,7 +559,7 @@ private[spark] class BlockManager( * Must be called while holding a read lock on the block. * Releases the read lock upon exception; keeps the read lock upon successful return. */ - private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): ChunkedByteBuffer = { + private def doGetLocalBytes(blockId: BlockId, info: BlockInfo): BlockData = { val level = info.level logDebug(s"Level for block $blockId is $level") // In order, try to read the serialized bytes from memory, then from disk, then fall back to @@ -520,17 +574,19 @@ private[spark] class BlockManager( diskStore.getBytes(blockId) } else if (level.useMemory && memoryStore.contains(blockId)) { // The block was not found on disk, so serialize an in-memory copy: - serializerManager.dataSerializeWithExplicitClassTag( - blockId, memoryStore.getValues(blockId).get, info.classTag) + new ByteBufferBlockData(serializerManager.dataSerializeWithExplicitClassTag( + blockId, memoryStore.getValues(blockId).get, info.classTag), true) } else { handleLocalReadFailure(blockId) } } else { // storage level is serialized if (level.useMemory && memoryStore.contains(blockId)) { - memoryStore.getBytes(blockId).get + new ByteBufferBlockData(memoryStore.getBytes(blockId).get, false) } else if (level.useDisk && diskStore.contains(blockId)) { - val diskBytes = diskStore.getBytes(blockId) - maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) + val diskData = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskData) + .map(new ByteBufferBlockData(_, false)) + .getOrElse(diskData) } else { handleLocalReadFailure(blockId) } @@ -553,12 +609,19 @@ private[spark] class BlockManager( /** * Return a list of locations for the given block, prioritizing the local machine since - * multiple block managers can share the same host. + * multiple block managers can share the same host, followed by hosts on the same rack. */ private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { val locs = Random.shuffle(master.getLocations(blockId)) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - preferredLocs ++ otherLocs + blockManagerId.topologyInfo match { + case None => preferredLocs ++ otherLocs + case Some(_) => + val (sameRackLocs, differentRackLocs) = otherLocs.partition { + loc => blockManagerId.topologyInfo == loc.topologyInfo + } + preferredLocs ++ sameRackLocs ++ differentRackLocs + } } /** @@ -745,15 +808,17 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, + new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) } /** * Put a new block of serialized bytes to the block manager. 
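`getLocations` now orders candidate replicas by locality: same host first, then (when topology information is available) same rack, then everything else. A standalone sketch of that ordering; `Loc` is a hypothetical stand-in for `BlockManagerId` with its optional topology field:

```scala
object LocationOrderSketch {
  // Hypothetical stand-in for BlockManagerId: host plus optional rack/topology info.
  final case class Loc(host: String, rack: Option[String])

  // Order candidate locations: same host first, then same rack, then everything else,
  // mirroring the prioritisation described for getLocations.
  def prioritize(self: Loc, candidates: Seq[Loc]): Seq[Loc] = {
    val (local, rest) = candidates.partition(_.host == self.host)
    self.rack match {
      case None => local ++ rest
      case Some(_) =>
        val (sameRack, other) = rest.partition(_.rack == self.rack)
        local ++ sameRack ++ other
    }
  }

  def main(args: Array[String]): Unit = {
    val self = Loc("host-a", Some("rack-1"))
    val peers = Seq(
      Loc("host-z", Some("rack-2")), Loc("host-b", Some("rack-1")), Loc("host-a", Some("rack-1")))
    println(prioritize(self, peers).map(_.host)) // List(host-a, host-b, host-z)
  }
}
```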
* + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @return true if the block was stored or false if an error occurred. */ def putBytes[T: ClassTag]( @@ -771,6 +836,9 @@ private[spark] class BlockManager( * * If the block already exists, this method will not overwrite it. * + * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing + * so may corrupt or change the data stored by the `BlockManager`. + * * @param keepReadLock if true, this method will hold the read lock when it returns (even if the * block already exists). If false, this method will hold no locks when it * returns. @@ -790,8 +858,9 @@ private[spark] class BlockManager( val replicationFuture = if (level.replication > 1) { Future { // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bytes, level, classTag) + // thread pool. The ByteBufferBlockData wrapper is not disposed of to avoid releasing + // buffers that are owned by the caller. + replicate(blockId, new ByteBufferBlockData(bytes, false), level, classTag) }(futureExecutionContext) } else { null @@ -814,7 +883,15 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) + val memoryMode = level.memoryMode + memoryStore.putBytes(blockId, size, memoryMode, () => { + if (memoryMode == MemoryMode.OFF_HEAP && + bytes.chunks.exists(buffer => !buffer.isDirect)) { + bytes.copy(Platform.allocateDirectBuffer) + } else { + bytes + } + }) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -962,8 +1039,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iter)(classTag) } size = diskStore.getSize(blockId) } else { @@ -978,8 +1056,9 @@ private[spark] class BlockManager( // Not enough space to unroll this block; drop to disk if applicable if (level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") - diskStore.put(blockId) { fileOutputStream => - partiallySerializedValues.finishWritingToStream(fileOutputStream) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + partiallySerializedValues.finishWritingToStream(out) } size = diskStore.getSize(blockId) } else { @@ -989,8 +1068,9 @@ private[spark] class BlockManager( } } else if (level.useDisk) { - diskStore.put(blockId) { fileOutputStream => - serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag) + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) + serializerManager.dataSerializeStream(blockId, out, iterator())(classTag) } size = diskStore.getSize(blockId) } @@ -1043,29 +1123,29 @@ private[spark] class BlockManager( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { + diskData: BlockData): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against 
a race condition where two readers both try to // put values read from disk into the MemoryStore. blockInfo.synchronized { if (memoryStore.contains(blockId)) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { val allocator = level.memoryMode match { case MemoryMode.ON_HEAP => ByteBuffer.allocate _ case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ } - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { + val putSucceeded = memoryStore.putBytes(blockId, diskData.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy(allocator) + diskData.toChunkedByteBuffer(allocator) }) if (putSucceeded) { - diskBytes.dispose() + diskData.dispose() Some(memoryStore.getBytes(blockId).get) } else { None @@ -1130,15 +1210,48 @@ private[spark] class BlockManager( } } + /** + * Called for pro-active replenishment of blocks lost due to executor failures + * + * @param blockId blockId being replicated + * @param existingReplicas existing block managers that have a replica + * @param maxReplicas maximum replicas needed + */ + def replicateBlock( + blockId: BlockId, + existingReplicas: Set[BlockManagerId], + maxReplicas: Int): Unit = { + logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") + blockInfoManager.lockForReading(blockId).foreach { info => + val data = doGetLocalBytes(blockId, info) + val storageLevel = StorageLevel( + useDisk = info.level.useDisk, + useMemory = info.level.useMemory, + useOffHeap = info.level.useOffHeap, + deserialized = info.level.deserialized, + replication = maxReplicas) + // we know we are called as a result of an executor removal, so we refresh peer cache + // this way, we won't try to replicate to a missing executor with a stale reference + getPeers(forceFetch = true) + try { + replicate(blockId, data, storageLevel, info.classTag, existingReplicas) + } finally { + logDebug(s"Releasing lock for $blockId") + releaseLockAndDispose(blockId, data) + } + } + } + /** * Replicate block to another node. Note that this is a blocking call that returns after * the block has been replicated.
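`replicateBlock` rebuilds lost replicas: it reads the local copy, bumps the requested replication to `maxReplicas`, and hands the surviving holders to `replicate` so they are neither counted again nor chosen as targets. A simplified standalone sketch of that accounting, assuming `existingReplicas` already includes every current holder and `maxReplicas` is the desired total; `Peer` is a hypothetical stand-in for `BlockManagerId`:

```scala
object ProactiveReplicationSketch {
  final case class Peer(id: String) // hypothetical stand-in for BlockManagerId

  // Choose which peers still need a copy so the total replica count reaches maxReplicas.
  def peersToReplicateTo(
      allPeers: Seq[Peer],
      existingReplicas: Set[Peer],
      maxReplicas: Int): Seq[Peer] = {
    val needed = maxReplicas - existingReplicas.size          // copies still missing
    allPeers.filterNot(existingReplicas.contains).take(math.max(needed, 0))
  }

  def main(args: Array[String]): Unit = {
    val peers = Seq(Peer("a"), Peer("b"), Peer("c"), Peer("d"))
    // One replica survived on "a"; we want 3 in total, so 2 more copies are needed.
    println(peersToReplicateTo(peers, Set(Peer("a")), maxReplicas = 3)) // List(Peer(b), Peer(c))
  }
}
```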
*/ private def replicate( blockId: BlockId, - data: ChunkedByteBuffer, + data: BlockData, level: StorageLevel, - classTag: ClassTag[_]): Unit = { + classTag: ClassTag[_], + existingReplicas: Set[BlockManagerId] = Set.empty): Unit = { val maxReplicationFailures = conf.getInt("spark.storage.maxReplicationFailures", 1) val tLevel = StorageLevel( @@ -1149,23 +1262,24 @@ private[spark] class BlockManager( replication = 1) val numPeersToReplicateTo = level.replication - 1 - val startTime = System.nanoTime - var peersReplicatedTo = mutable.HashSet.empty[BlockManagerId] + var peersReplicatedTo = mutable.HashSet.empty ++ existingReplicas var peersFailedToReplicateTo = mutable.HashSet.empty[BlockManagerId] var numFailures = 0 + val initialPeers = getPeers(false).filterNot(existingReplicas.contains(_)) + var peersForReplication = blockReplicationPolicy.prioritize( blockManagerId, - getPeers(false), - mutable.HashSet.empty, + initialPeers, + peersReplicatedTo, blockId, numPeersToReplicateTo) while(numFailures <= maxReplicationFailures && - !peersForReplication.isEmpty && - peersReplicatedTo.size != numPeersToReplicateTo) { + !peersForReplication.isEmpty && + peersReplicatedTo.size < numPeersToReplicateTo) { val peer = peersForReplication.head try { val onePeerStartTime = System.nanoTime @@ -1175,7 +1289,7 @@ private[spark] class BlockManager( peer.port, peer.executorId, blockId, - new NettyManagedBuffer(data.toNetty), + new BlockManagerManagedBuffer(blockInfoManager, blockId, data, false), tLevel, classTag) logTrace(s"Replicated $blockId of ${data.size} bytes to $peer" + @@ -1202,7 +1316,6 @@ private[spark] class BlockManager( numPeersToReplicateTo - peersReplicatedTo.size) } } - logDebug(s"Replicating $blockId of ${data.size} bytes to " + s"${peersReplicatedTo.size} peer(s) took ${(System.nanoTime - startTime) / 1e6} ms") if (peersReplicatedTo.size < numPeersToReplicateTo) { @@ -1258,10 +1371,11 @@ private[spark] class BlockManager( logInfo(s"Writing block $blockId to disk") data() match { case Left(elements) => - diskStore.put(blockId) { fileOutputStream => + diskStore.put(blockId) { channel => + val out = Channels.newOutputStream(channel) serializerManager.dataSerializeStream( blockId, - fileOutputStream, + out, elements.toIterator)(info.classTag.asInstanceOf[ClassTag[T]]) } case Right(bytes) => @@ -1353,6 +1467,11 @@ private[spark] class BlockManager( } } + def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = { + blockInfoManager.unlock(blockId) + data.dispose() + } + def stop(): Unit = { blockTransferService.close() if (shuffleClient ne blockTransferService) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index c37a3604d28fa..2c3da0ee85e06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -46,7 +46,7 @@ class BlockManagerId private ( def executorId: String = executorId_ if (null != host_) { - Utils.checkHost(host_, "Expected hostname") + Utils.checkHost(host_) assert (port_ > 0) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala index f66f942798550..1ea0d378cbe87 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala @@ -17,31 +17,52 @@ package 
org.apache.spark.storage -import org.apache.spark.network.buffer.{ManagedBuffer, NettyManagedBuffer} +import java.io.InputStream +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.util.io.ChunkedByteBuffer /** - * This [[ManagedBuffer]] wraps a [[ChunkedByteBuffer]] retrieved from the [[BlockManager]] + * This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]] * so that the corresponding block's read lock can be released once this buffer's references * are released. * + * If `dispose` is set to true, the [[BlockData]]will be disposed when the buffer's reference + * count drops to zero. + * * This is effectively a wrapper / bridge to connect the BlockManager's notion of read locks * to the network layer's notion of retain / release counts. */ private[storage] class BlockManagerManagedBuffer( blockInfoManager: BlockInfoManager, blockId: BlockId, - chunkedBuffer: ChunkedByteBuffer) extends NettyManagedBuffer(chunkedBuffer.toNetty) { + data: BlockData, + dispose: Boolean) extends ManagedBuffer { + + private val refCount = new AtomicInteger(1) + + override def size(): Long = data.size + + override def nioByteBuffer(): ByteBuffer = data.toByteBuffer() + + override def createInputStream(): InputStream = data.toInputStream() + + override def convertToNetty(): Object = data.toNetty() override def retain(): ManagedBuffer = { - super.retain() + refCount.incrementAndGet() val locked = blockInfoManager.lockForReading(blockId, blocking = false) assert(locked.isDefined) this - } + } override def release(): ManagedBuffer = { blockInfoManager.unlock(blockId) - super.release() + if (refCount.decrementAndGet() == 0 && dispose) { + data.dispose() + } + this } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 7a600068912b1..ea5d8423a588c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -57,11 +57,12 @@ class BlockManagerMaster( */ def registerBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { logInfo(s"Registering BlockManager $blockManagerId") - val updatedId = driverEndpoint.askWithRetry[BlockManagerId]( - RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + val updatedId = driverEndpoint.askSync[BlockManagerId]( + RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) logInfo(s"Registered BlockManager $updatedId") updatedId } @@ -72,7 +73,7 @@ class BlockManagerMaster( storageLevel: StorageLevel, memSize: Long, diskSize: Long): Boolean = { - val res = driverEndpoint.askWithRetry[Boolean]( + val res = driverEndpoint.askSync[Boolean]( UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) logDebug(s"Updated info of block $blockId") res @@ -80,12 +81,12 @@ class BlockManagerMaster( /** Get locations of the blockId from the driver */ def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetLocations(blockId)) + driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] 
= { - driverEndpoint.askWithRetry[IndexedSeq[Seq[BlockManagerId]]]( + driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( GetLocationsMultipleBlockIds(blockIds)) } @@ -99,11 +100,11 @@ class BlockManagerMaster( /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = { - driverEndpoint.askWithRetry[Seq[BlockManagerId]](GetPeers(blockManagerId)) + driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId)) } def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = { - driverEndpoint.askWithRetry[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) + driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId)) } /** @@ -111,12 +112,12 @@ class BlockManagerMaster( * blocks that the driver knows about. */ def removeBlock(blockId: BlockId) { - driverEndpoint.askWithRetry[Boolean](RemoveBlock(blockId)) + driverEndpoint.askSync[Boolean](RemoveBlock(blockId)) } /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) + val future = driverEndpoint.askSync[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}", e) @@ -128,7 +129,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + val future = driverEndpoint.askSync[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}", e) @@ -140,7 +141,7 @@ class BlockManagerMaster( /** Remove all blocks belonging to the given broadcast. */ def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { - val future = driverEndpoint.askWithRetry[Future[Seq[Int]]]( + val future = driverEndpoint.askSync[Future[Seq[Int]]]( RemoveBroadcast(broadcastId, removeFromMaster)) future.onFailure { case e: Exception => @@ -159,11 +160,11 @@ class BlockManagerMaster( * amount of remaining memory. */ def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = { - driverEndpoint.askWithRetry[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) + driverEndpoint.askSync[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus) } def getStorageStatus: Array[StorageStatus] = { - driverEndpoint.askWithRetry[Array[StorageStatus]](GetStorageStatus) + driverEndpoint.askSync[Array[StorageStatus]](GetStorageStatus) } /** @@ -184,7 +185,7 @@ class BlockManagerMaster( * master endpoint for a response to a prior message. */ val response = driverEndpoint. - askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + askSync[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip implicit val sameThread = ThreadUtils.sameThread val cbf = @@ -214,7 +215,7 @@ class BlockManagerMaster( filter: BlockId => Boolean, askSlaves: Boolean): Seq[BlockId] = { val msg = GetMatchingBlockIds(filter, askSlaves) - val future = driverEndpoint.askWithRetry[Future[Seq[BlockId]]](msg) + val future = driverEndpoint.askSync[Future[Seq[BlockId]]](msg) timeout.awaitResult(future) } @@ -223,7 +224,7 @@ class BlockManagerMaster( * since they are not reported the master. 
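The `BlockManagerManagedBuffer` rewrite a little above pairs the block's read lock with an explicit reference count, disposing the underlying `BlockData` only when the last reference is released and disposal was requested. A minimal standalone sketch of that retain/release discipline; the `RefCounted` class is illustrative, not a Spark type:

```scala
import java.util.concurrent.atomic.AtomicInteger

object RefCountedBufferSketch {
  // Simplified analogue: pairs a payload with a reference count and a cleanup action.
  final class RefCounted[A](val value: A, disposeOnZero: Boolean)(cleanup: A => Unit) {
    private val refCount = new AtomicInteger(1)
    def retain(): this.type = { refCount.incrementAndGet(); this }
    def release(): this.type = {
      if (refCount.decrementAndGet() == 0 && disposeOnZero) cleanup(value)
      this
    }
  }

  def main(args: Array[String]): Unit = {
    var disposed = false
    val buf = new RefCounted("block-bytes", disposeOnZero = true)(_ => disposed = true)
    buf.retain()       // a second reader takes a reference
    buf.release()      // the first reader is done; one reference remains
    println(disposed)  // false
    buf.release()      // last reference released; cleanup runs
    println(disposed)  // true
  }
}
```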
*/ def hasCachedBlocks(executorId: String): Boolean = { - driverEndpoint.askWithRetry[Boolean](HasCachedBlocks(executorId)) + driverEndpoint.askSync[Boolean](HasCachedBlocks(executorId)) } /** Stop the driver endpoint, called only on the Spark driver node */ @@ -237,7 +238,7 @@ class BlockManagerMaster( /** Send a one-way message to the master endpoint, to which we expect it to reply with true. */ private def tell(message: Any) { - if (!driverEndpoint.askWithRetry[Boolean](message)) { + if (!driverEndpoint.askSync[Boolean](message)) { throw new SparkException("BlockManagerMasterEndpoint returned false, expected true.") } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 145c434a4f0cf..6f85b9e4d6c73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -22,6 +22,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} +import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi @@ -65,11 +66,13 @@ class BlockManagerMasterEndpoint( mapper } + val proactivelyReplicate = conf.get("spark.storage.replication.proactive", "false").toBoolean + logInfo("BlockManagerMasterEndpoint up") override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - context.reply(register(blockManagerId, maxMemSize, slaveEndpoint)) + case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => + context.reply(register(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint)) case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => @@ -195,17 +198,38 @@ class BlockManagerMasterEndpoint( // Remove it from blockManagerInfo and remove all the blocks. blockManagerInfo.remove(blockManagerId) + val iterator = info.blocks.keySet.iterator while (iterator.hasNext) { val blockId = iterator.next val locations = blockLocations.get(blockId) locations -= blockManagerId + // De-register the block if none of the block managers have it. Otherwise, if pro-active + // replication is enabled, and a block is either an RDD or a test block (the latter is used + // for unit testing), we send a message to a randomly chosen executor location to replicate + // the given block. Note that we ignore other block types (such as broadcast/shuffle blocks + // etc.) as replication doesn't make much sense in that context. 
if (locations.size == 0) { blockLocations.remove(blockId) + logWarning(s"No more replicas available for $blockId!") + } else if (proactivelyReplicate && (blockId.isRDD || blockId.isInstanceOf[TestBlockId])) { + // As a heuristic, assume a single executor failure to find out the number of replicas that + // existed before the failure + val maxReplicas = locations.size + 1 + val i = (new Random(blockId.hashCode)).nextInt(locations.size) + val blockLocations = locations.toSeq + val candidateBMId = blockLocations(i) + blockManagerInfo.get(candidateBMId).foreach { bm => + val remainingLocations = locations.toSeq.filter(bm => bm != candidateBMId) + val replicateMsg = ReplicateBlock(blockId, remainingLocations, maxReplicas) + bm.slaveEndpoint.ask[Boolean](replicateMsg) + } } } + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) logInfo(s"Removing block manager $blockManagerId") + } private def removeExecutor(execId: String) { @@ -252,7 +276,8 @@ class BlockManagerMasterEndpoint( private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => - new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) + new StorageStatus(blockManagerId, info.maxMem, Some(info.maxOnHeapMem), + Some(info.maxOffHeapMem), info.blocks.asScala) }.toArray } @@ -314,7 +339,8 @@ class BlockManagerMasterEndpoint( */ private def register( idWithoutTopologyInfo: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, slaveEndpoint: RpcEndpointRef): BlockManagerId = { // the dummy id is not expected to contain the topology information. // we get that info here and respond back with a more fleshed out block manager id @@ -335,14 +361,15 @@ class BlockManagerMasterEndpoint( case None => } logInfo("Registering block manager %s with %s RAM, %s".format( - id.hostPort, Utils.bytesToString(maxMemSize), id)) + id.hostPort, Utils.bytesToString(maxOnHeapMemSize + maxOffHeapMemSize), id)) blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) } - listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, + Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) id } @@ -440,10 +467,13 @@ object BlockStatus { private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val maxMem: Long, + val maxOnHeapMem: Long, + val maxOffHeapMem: Long, val slaveEndpoint: RpcEndpointRef) extends Logging { + val maxMem = maxOnHeapMem + maxOffHeapMem + private var _lastSeenMs: Long = timeMs private var _remainingMem: Long = maxMem @@ -467,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already.
val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -490,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 6bded92700504..0c0ff144596ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -32,6 +32,10 @@ private[spark] object BlockManagerMessages { // blocks that the master knows about. 
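Stepping back to the `updateBlockInfo` rework just above: a condensed, illustrative sketch of its memory accounting under simplified fields. The real method also tracks disk sizes and cached-block membership and emits the added/updated/removed log lines shown in the diff; this only captures the release-then-recharge pattern.

```scala
import scala.collection.mutable

// Condensed accounting sketch only, not the actual BlockManagerInfo code.
class BlockAccountingSketch(maxMem: Long) {
  private val blocks = mutable.Map.empty[String, Long] // blockId -> memory footprint
  private var remainingMem: Long = maxMem

  def update(blockId: String, newMemSize: Long): Unit = {
    val blockExists = blocks.contains(blockId)
    val originalMemSize = blocks.getOrElse(blockId, 0L)
    if (blockExists) remainingMem += originalMemSize // release the old charge first
    if (newMemSize > 0) {
      blocks(blockId) = newMemSize
      remainingMem -= newMemSize // then charge the new size
    } else if (blockExists) {
      blocks.remove(blockId) // block dropped; nothing new to charge
    }
  }

  def free: Long = remainingMem
}
```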
case class RemoveBlock(blockId: BlockId) extends ToBlockManagerSlave + // Replicate blocks that were lost due to executor failure + case class ReplicateBlock(blockId: BlockId, replicas: Seq[BlockManagerId], maxReplicas: Int) + extends ToBlockManagerSlave + // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave @@ -43,7 +47,7 @@ private[spark] object BlockManagerMessages { extends ToBlockManagerSlave /** - * Driver -> Executor message to trigger a thread dump. + * Driver to Executor message to trigger a thread dump. */ case object TriggerThreadDump extends ToBlockManagerSlave @@ -54,7 +58,8 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, - maxMemSize: Long, + maxOnHeapMemSize: Long, + maxOffHeapMemSize: Long, sender: RpcEndpointRef) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index d17ddbc162579..1aaa42459df69 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -74,6 +74,10 @@ class BlockManagerSlaveEndpoint( case TriggerThreadDump => context.reply(Utils.getThreadDump()) + + case ReplicateBlock(blockId, replicas, maxReplicas) => + context.reply(blockManager.replicateBlock(blockId, replicas.toSet, maxReplicas)) + } private def doAsync[T](actionMessage: String, context: RpcCallContext)(body: => T) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index c5ba9af3e2658..197a01762c0c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -26,35 +26,39 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager) override val metricRegistry = new MetricRegistry() override val sourceName = "BlockManager" - metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - maxMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val remainingMem = storageStatusList.map(_.memRemaining).sum - remainingMem / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val memUsed = storageStatusList.map(_.memUsed).sum - memUsed / 1024 / 1024 - } - }) - - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { - override def getValue: Long = { - val storageStatusList = blockManager.master.getStorageStatus - val diskSpaceUsed = storageStatusList.map(_.diskUsed).sum - diskSpaceUsed / 1024 / 1024 - } - }) + private def registerGauge(name: String, func: BlockManagerMaster => Long): Unit = { + metricRegistry.register(name, new Gauge[Long] { + override def getValue: Long = func(blockManager.master) / 1024 / 1024 + }) + } + + registerGauge(MetricRegistry.name("memory", "maxMem_MB"), 
+ _.getStorageStatus.map(_.maxMem).sum) + + registerGauge(MetricRegistry.name("memory", "maxOnHeapMem_MB"), + _.getStorageStatus.map(_.maxOnHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "maxOffHeapMem_MB"), + _.getStorageStatus.map(_.maxOffHeapMem.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingMem_MB"), + _.getStorageStatus.map(_.memRemaining).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOnHeapMem_MB"), + _.getStorageStatus.map(_.onHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "remainingOffHeapMem_MB"), + _.getStorageStatus.map(_.offHeapMemRemaining.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "memUsed_MB"), + _.getStorageStatus.map(_.memUsed).sum) + + registerGauge(MetricRegistry.name("memory", "onHeapMemUsed_MB"), + _.getStorageStatus.map(_.onHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("memory", "offHeapMemUsed_MB"), + _.getStorageStatus.map(_.offHeapMemUsed.getOrElse(0L)).sum) + + registerGauge(MetricRegistry.name("disk", "diskSpaceUsed_MB"), + _.getStorageStatus.map(_.diskUsed).sum) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala index bf087af16a5b1..353eac60df171 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockReplicationPolicy.scala @@ -53,6 +53,46 @@ trait BlockReplicationPolicy { numReplicas: Int): List[BlockManagerId] } +object BlockReplicationUtils { + // scalastyle:off line.size.limit + /** + * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while + * minimizing space usage. Please see + * here. + * + * @param n total number of indices + * @param m number of samples needed + * @param r random number generator + * @return list of m random unique indices + */ + // scalastyle:on line.size.limit + private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { + val indices = (n - m + 1 to n).foldLeft(mutable.LinkedHashSet.empty[Int]) {case (set, i) => + val t = r.nextInt(i) + 1 + if (set.contains(t)) set + i else set + t + } + indices.map(_ - 1).toList + } + + /** + * Get a random sample of size m from the elems + * + * @param elems + * @param m number of samples needed + * @param r random number generator + * @tparam T + * @return a random list of size m. If there are fewer than m elements in elems, we just + * randomly shuffle elems + */ + def getRandomSample[T](elems: Seq[T], m: Int, r: Random): List[T] = { + if (elems.size > m) { + getSampleIds(elems.size, m, r).map(elems(_)) + } else { + r.shuffle(elems).toList + } + } +} + @DeveloperApi class RandomBlockReplicationPolicy extends BlockReplicationPolicy @@ -67,6 +107,7 @@ class RandomBlockReplicationPolicy * @param peersReplicatedTo Set of peers already replicated to * @param blockId BlockId of the block being replicated. This can be used as a source of * randomness if needed. + * @param numReplicas Number of peers we need to replicate to * @return A prioritized list of peers. 
Lower the index of a peer, higher its priority */ override def prioritize( @@ -78,7 +119,7 @@ class RandomBlockReplicationPolicy val random = new Random(blockId.hashCode) logDebug(s"Input peers : ${peers.mkString(", ")}") val prioritizedPeers = if (peers.size > numReplicas) { - getSampleIds(peers.size, numReplicas, random).map(peers(_)) + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) } else { if (peers.size < numReplicas) { logWarning(s"Expecting ${numReplicas} replicas with only ${peers.size} peer/s.") @@ -88,25 +129,96 @@ class RandomBlockReplicationPolicy logDebug(s"Prioritized peers : ${prioritizedPeers.mkString(", ")}") prioritizedPeers } +} + +@DeveloperApi +class BasicBlockReplicationPolicy + extends BlockReplicationPolicy + with Logging { /** - * Uses sampling algorithm by Robert Floyd. Finds a random sample in O(n) while - * minimizing space usage - * [[http://math.stackexchange.com/questions/178690/ - * whats-the-proof-of-correctness-for-robert-floyds-algorithm-for-selecting-a-sin]] + * Method to prioritize a bunch of candidate peers of a block manager. This implementation + * replicates the behavior of block replication in HDFS. For a given number of replicas needed, + * we choose a peer within the rack, one outside and remaining blockmanagers are chosen at + * random, in that order till we meet the number of replicas needed. + * This works best with a total replication factor of 3, like HDFS. * - * @param n total number of indices - * @param m number of samples needed - * @param r random number generator - * @return list of m random unique indices + * @param blockManagerId Id of the current BlockManager for self identification + * @param peers A list of peers of a BlockManager + * @param peersReplicatedTo Set of peers already replicated to + * @param blockId BlockId of the block being replicated. This can be used as a source of + * randomness if needed. + * @param numReplicas Number of peers we need to replicate to + * @return A prioritized list of peers. Lower the index of a peer, higher its priority */ - private def getSampleIds(n: Int, m: Int, r: Random): List[Int] = { - val indices = (n - m + 1 to n).foldLeft(Set.empty[Int]) {case (set, i) => - val t = r.nextInt(i) + 1 - if (set.contains(t)) set + i else set + t + override def prioritize( + blockManagerId: BlockManagerId, + peers: Seq[BlockManagerId], + peersReplicatedTo: mutable.HashSet[BlockManagerId], + blockId: BlockId, + numReplicas: Int): List[BlockManagerId] = { + + logDebug(s"Input peers : $peers") + logDebug(s"BlockManagerId : $blockManagerId") + + val random = new Random(blockId.hashCode) + + // if block doesn't have topology info, we can't do much, so we randomly shuffle + // if there is, we see what's needed from peersReplicatedTo and based on numReplicas, + // we choose whats needed + if (blockManagerId.topologyInfo.isEmpty || numReplicas == 0) { + // no topology info for the block. 
The best we can do is randomly choose peers + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we have topology information, we see what is left to be done from peersReplicatedTo + val doneWithinRack = peersReplicatedTo.exists(_.topologyInfo == blockManagerId.topologyInfo) + val doneOutsideRack = peersReplicatedTo.exists { p => + p.topologyInfo.isDefined && p.topologyInfo != blockManagerId.topologyInfo + } + + if (doneOutsideRack && doneWithinRack) { + // we are done, we just return a random sample + BlockReplicationUtils.getRandomSample(peers, numReplicas, random) + } else { + // we separate peers within and outside rack + val (inRackPeers, outOfRackPeers) = peers + .filter(_.host != blockManagerId.host) + .partition(_.topologyInfo == blockManagerId.topologyInfo) + + val peerWithinRack = if (doneWithinRack) { + // we are done with in-rack replication, so don't need anymore peers + Seq.empty + } else { + if (inRackPeers.isEmpty) { + Seq.empty + } else { + Seq(inRackPeers(random.nextInt(inRackPeers.size))) + } + } + + val peerOutsideRack = if (doneOutsideRack || numReplicas - peerWithinRack.size <= 0) { + Seq.empty + } else { + if (outOfRackPeers.isEmpty) { + Seq.empty + } else { + Seq(outOfRackPeers(random.nextInt(outOfRackPeers.size))) + } + } + + val priorityPeers = peerWithinRack ++ peerOutsideRack + val numRemainingPeers = numReplicas - priorityPeers.size + val remainingPeers = if (numRemainingPeers > 0) { + val rPeers = peers.filter(p => !priorityPeers.contains(p)) + BlockReplicationUtils.getRandomSample(rPeers, numRemainingPeers, random) + } else { + Seq.empty + } + + (priorityPeers ++ remainingPeers).toList + } + } - // we shuffle the result to ensure a random arrangement within the sample - // to avoid any bias from set implementations - r.shuffle(indices.map(_ - 1).toList) } + } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827ae1598..eb3ff926372a2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.util.Utils /** @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class DiskBlockObjectWriter( val file: File, + serializerManager: SerializerManager, serializerInstance: SerializerInstance, bufferSize: Int, - wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. 
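Stepping back to `BlockReplicationUtils` introduced above: a self-contained sketch of the Robert Floyd sampling technique it relies on, which selects `m` distinct indices from `n` in O(m) time and space, plus the fallback shuffle when fewer than `m` elements exist. This mirrors the intent of `getRandomSample`, not its exact code.

```scala
import scala.collection.mutable
import scala.util.Random

object FloydSamplingSketch {
  // Robert Floyd's sampling: for each i in (n - m) until n, draw t uniformly in [0, i];
  // insert t unless it is already chosen, in which case insert i. Yields m distinct indices.
  def sampleIndices(n: Int, m: Int, r: Random): List[Int] = {
    val chosen = mutable.LinkedHashSet.empty[Int]
    for (i <- (n - m) until n) {
      val t = r.nextInt(i + 1)
      if (chosen.contains(t)) chosen += i else chosen += t
    }
    chosen.toList
  }

  // Random sample of size m; if there are not enough elements, just shuffle them all.
  def randomSample[T](elems: Seq[T], m: Int, r: Random): List[T] =
    if (elems.size > m) sampleIndices(elems.size, m, r).map(elems(_))
    else r.shuffle(elems).toList
}

// Example: FloydSamplingSketch.randomSample(Seq("a", "b", "c", "d", "e"), 2, new Random(42))
```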
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + bs = serializerManager.wrapStream(blockId, mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this @@ -128,16 +128,19 @@ private[spark] class DiskBlockObjectWriter( */ private def closeResources(): Unit = { if (initialized) { - mcs.manualClose() - channel = null - mcs = null - bs = null - fos = null - ts = null - objOut = null - initialized = false - streamOpen = false - hasBeenClosed = true + Utils.tryWithSafeFinally { + mcs.manualClose() + } { + channel = null + mcs = null + bs = null + fos = null + ts = null + objOut = null + initialized = false + streamOpen = false + hasBeenClosed = true + } } } @@ -199,26 +202,29 @@ private[spark] class DiskBlockObjectWriter( def revertPartialWritesAndClose(): File = { // Discard current writes. We do this by flushing the outstanding writes and then // truncating the file to its initial position. - try { + Utils.tryWithSafeFinally { if (initialized) { writeMetrics.decBytesWritten(reportedPosition - committedPosition) writeMetrics.decRecordsWritten(numRecordsWritten) streamOpen = false closeResources() } - - val truncateStream = new FileOutputStream(file, true) + } { + var truncateStream: FileOutputStream = null try { + truncateStream = new FileOutputStream(file, true) truncateStream.getChannel.truncate(committedPosition) - file + } catch { + case e: Exception => + logError("Uncaught exception while reverting partial writes to file " + file, e) } finally { - truncateStream.close() + if (truncateStream != null) { + truncateStream.close() + truncateStream = null + } } - } catch { - case e: Exception => - logError("Uncaught exception while reverting partial writes to file " + file, e) - file } + file } /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index ca23e2391ed02..c6656341fcd15 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,48 +17,67 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, IOException, RandomAccessFile} +import java.io._ import java.nio.ByteBuffer +import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel} import java.nio.channels.FileChannel.MapMode +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.ConcurrentHashMap -import com.google.common.io.Closeables +import scala.collection.mutable.ListBuffer -import org.apache.spark.SparkConf +import com.google.common.io.{ByteStreams, Closeables, Files} +import io.netty.channel.FileRegion +import io.netty.util.AbstractReferenceCounted + +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.security.CryptoStreamUtils +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ChunkedByteBuffer /** * Stores BlockManager blocks on disk. 
*/ -private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) extends Logging { +private[spark] class DiskStore( + conf: SparkConf, + diskManager: DiskBlockManager, + securityManager: SecurityManager) extends Logging { private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") + private val blockSizes = new ConcurrentHashMap[String, Long]() - def getSize(blockId: BlockId): Long = { - diskManager.getFile(blockId.name).length - } + def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) /** * Invokes the provided callback function to write the specific block. * * @throws IllegalStateException if the block already exists in the disk store. */ - def put(blockId: BlockId)(writeFunc: FileOutputStream => Unit): Unit = { + def put(blockId: BlockId)(writeFunc: WritableByteChannel => Unit): Unit = { if (contains(blockId)) { throw new IllegalStateException(s"Block $blockId is already present in the disk store") } logDebug(s"Attempting to put block $blockId") val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) - val fileOutputStream = new FileOutputStream(file) + val out = new CountingWritableChannel(openForWrite(file)) var threwException: Boolean = true try { - writeFunc(fileOutputStream) + writeFunc(out) + blockSizes.put(blockId.name, out.getCount) threwException = false } finally { try { - Closeables.close(fileOutputStream, threwException) + out.close() + } catch { + case ioe: IOException => + if (!threwException) { + threwException = true + throw ioe + } } finally { if (threwException) { remove(blockId) @@ -73,41 +92,46 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e } def putBytes(blockId: BlockId, bytes: ChunkedByteBuffer): Unit = { - put(blockId) { fileOutputStream => - val channel = fileOutputStream.getChannel - Utils.tryWithSafeFinally { - bytes.writeFully(channel) - } { - channel.close() - } + put(blockId) { channel => + bytes.writeFully(channel) } } - def getBytes(blockId: BlockId): ChunkedByteBuffer = { + def getBytes(blockId: BlockId): BlockData = { val file = diskManager.getFile(blockId.name) - val channel = new RandomAccessFile(file, "r").getChannel - Utils.tryWithSafeFinally { - // For small files, directly read rather than memory map - if (file.length < minMemoryMapBytes) { - val buf = ByteBuffer.allocate(file.length.toInt) - channel.position(0) - while (buf.remaining() != 0) { - if (channel.read(buf) == -1) { - throw new IOException("Reached EOF before filling buffer\n" + - s"offset=0\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + val blockSize = getSize(blockId) + + securityManager.getIOEncryptionKey() match { + case Some(key) => + // Encrypted blocks cannot be memory mapped; return a special object that does decryption + // and provides InputStream / FileRegion implementations for reading the data. + new EncryptedBlockData(file, blockSize, conf, key) + + case _ => + val channel = new FileInputStream(file).getChannel() + if (blockSize < minMemoryMapBytes) { + // For small files, directly read rather than memory map. 
+ Utils.tryWithSafeFinally { + val buf = ByteBuffer.allocate(blockSize.toInt) + JavaUtils.readFully(channel, buf) + buf.flip() + new ByteBufferBlockData(new ChunkedByteBuffer(buf), true) + } { + channel.close() + } + } else { + Utils.tryWithSafeFinally { + new ByteBufferBlockData( + new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)), true) + } { + channel.close() } } - buf.flip() - new ChunkedByteBuffer(buf) - } else { - new ChunkedByteBuffer(channel.map(MapMode.READ_ONLY, 0, file.length)) - } - } { - channel.close() } } def remove(blockId: BlockId): Boolean = { + blockSizes.remove(blockId.name) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() @@ -124,4 +148,142 @@ private[spark] class DiskStore(conf: SparkConf, diskManager: DiskBlockManager) e val file = diskManager.getFile(blockId.name) file.exists() } + + private def openForWrite(file: File): WritableByteChannel = { + val out = new FileOutputStream(file).getChannel() + try { + securityManager.getIOEncryptionKey().map { key => + CryptoStreamUtils.createWritableChannel(out, conf, key) + }.getOrElse(out) + } catch { + case e: Exception => + Closeables.close(out, true) + file.delete() + throw e + } + } + +} + +private class EncryptedBlockData( + file: File, + blockSize: Long, + conf: SparkConf, + key: Array[Byte]) extends BlockData { + + override def toInputStream(): InputStream = Channels.newInputStream(open()) + + override def toNetty(): Object = new ReadableChannelFileRegion(open(), blockSize) + + override def toChunkedByteBuffer(allocator: Int => ByteBuffer): ChunkedByteBuffer = { + val source = open() + try { + var remaining = blockSize + val chunks = new ListBuffer[ByteBuffer]() + while (remaining > 0) { + val chunkSize = math.min(remaining, Int.MaxValue) + val chunk = allocator(chunkSize.toInt) + remaining -= chunkSize + JavaUtils.readFully(source, chunk) + chunk.flip() + chunks += chunk + } + + new ChunkedByteBuffer(chunks.toArray) + } finally { + source.close() + } + } + + override def toByteBuffer(): ByteBuffer = { + // This is used by the block transfer service to replicate blocks. The upload code reads + // all bytes into memory to send the block to the remote executor, so it's ok to do this + // as long as the block fits in a Java array. 
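The `WritableByteChannel` produced by `openForWrite` above is what the new `put` hands to its caller, already wrapped for encryption when an I/O encryption key is configured and wrapped again by the counting channel so `getSize` can be answered from memory. A hedged usage sketch, where `diskStore` and `blockId` are assumed to exist in the surrounding code:

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

// Hypothetical caller of the channel-based DiskStore.put; diskStore and blockId are assumed.
val payload = "example block contents".getBytes(UTF_8)
diskStore.put(blockId) { channel =>
  val buf = ByteBuffer.wrap(payload)
  while (buf.hasRemaining) {
    channel.write(buf) // CountingWritableChannel tallies the bytes for later getSize(blockId)
  }
}
```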
+ assert(blockSize <= Int.MaxValue, "Block is too large to be wrapped in a byte buffer.") + val dst = ByteBuffer.allocate(blockSize.toInt) + val in = open() + try { + JavaUtils.readFully(in, dst) + dst.flip() + dst + } finally { + Closeables.close(in, true) + } + } + + override def size: Long = blockSize + + override def dispose(): Unit = { } + + private def open(): ReadableByteChannel = { + val channel = new FileInputStream(file).getChannel() + try { + CryptoStreamUtils.createReadableChannel(channel, conf, key) + } catch { + case e: Exception => + Closeables.close(channel, true) + throw e + } + } + +} + +private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long) + extends AbstractReferenceCounted with FileRegion { + + private var _transferred = 0L + + private val buffer = ByteBuffer.allocateDirect(64 * 1024) + buffer.flip() + + override def count(): Long = blockSize + + override def position(): Long = 0 + + override def transfered(): Long = _transferred + + override def transferTo(target: WritableByteChannel, pos: Long): Long = { + assert(pos == transfered(), "Invalid position.") + + var written = 0L + var lastWrite = -1L + while (lastWrite != 0) { + if (!buffer.hasRemaining()) { + buffer.clear() + source.read(buffer) + buffer.flip() + } + if (buffer.hasRemaining()) { + lastWrite = target.write(buffer) + written += lastWrite + } else { + lastWrite = 0 + } + } + + _transferred += written + written + } + + override def deallocate(): Unit = source.close() +} + +private class CountingWritableChannel(sink: WritableByteChannel) extends WritableByteChannel { + + private var count = 0L + + def getCount: Long = count + + override def write(src: ByteBuffer): Int = { + val written = sink.write(src) + if (written > 0) { + count += written + } + written + } + + override def isOpen(): Boolean = sink.isOpen() + + override def close(): Unit = sink.close() + } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4dc2f362329a0..f8906117638b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,19 +17,21 @@ package org.apache.spark.storage -import java.io.InputStream +import java.io.{InputStream, IOException} +import java.nio.ByteBuffer import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy +import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.control.NonFatal import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -47,8 +49,10 @@ import org.apache.spark.util.Utils * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. + * @param streamWrapper A function to wrap the returned input stream. 
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param detectCorrupt whether to detect any corruption in fetched blocks. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -56,8 +60,10 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, - maxReqsInFlight: Int) + maxReqsInFlight: Int, + detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator( * Current [[FetchResult]] being processed. We track this so we can release the current buffer * in case of a runtime exception when processing the current buffer. */ - @volatile private[this] var currentResult: FetchResult = null + @volatile private[this] var currentResult: SuccessFetchResult = null /** * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that @@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator( /** Current number of requests in flight */ private[this] var reqsInFlight = 0 + /** + * The blocks that can't be decompressed successfully, it is used to guarantee that we retry + * at most once for those corrupted blocks. + */ + private[this] val corruptedBlocks = mutable.HashSet[BlockId]() + private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() /** @@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator( // The currentResult is set to null to prevent releasing the buffer again on cleanup() private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary - currentResult match { - case SuccessFetchResult(_, _, _, buf, _) => buf.release() - case _ => + if (currentResult != null) { + currentResult.buf.release() } currentResult = null } @@ -247,7 +258,7 @@ final class ShuffleBlockFetcherIterator( /** * Fetch the local blocks while we are fetching remote blocks. This is ok because - * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { @@ -304,41 +315,89 @@ final class ShuffleBlockFetcherIterator( * Throws a FetchFailedException if the next block could not be fetched. 
*/ override def next(): (BlockId, InputStream) = { - numBlocksProcessed += 1 - val startFetchWait = System.currentTimeMillis() - currentResult = results.take() - val result = currentResult - val stopFetchWait = System.currentTimeMillis() - shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) - - result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => - if (address != blockManager.blockManagerId) { - shuffleMetrics.incRemoteBytesRead(buf.size) - shuffleMetrics.incRemoteBlocksFetched(1) - } - bytesInFlight -= size - if (isNetworkReqDone) { - reqsInFlight -= 1 - logDebug("Number of requests in flight " + reqsInFlight) - } - case _ => + if (!hasNext) { + throw new NoSuchElementException } - // Send fetch requests up to maxBytesInFlight - fetchUpToMaxBytes() - result match { - case FailureFetchResult(blockId, address, e) => - throwFetchFailedException(blockId, address, e) + numBlocksProcessed += 1 - case SuccessFetchResult(blockId, address, _, buf, _) => - try { - (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) - } catch { - case NonFatal(t) => - throwFetchFailedException(blockId, address, t) - } + var result: FetchResult = null + var input: InputStream = null + // Take the next fetched result and try to decompress it to detect data corruption, + // then fetch it one more time if it's corrupt; throw FailureFetchResult if the second fetch + // is also corrupt, so the previous stage can be retried. + // For a local shuffle block, throw FailureFetchResult on the first IOException. + while (result == null) { + val startFetchWait = System.currentTimeMillis() + result = results.take() + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) + + result match { + case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => + if (address != blockManager.blockManagerId) { + shuffleMetrics.incRemoteBytesRead(buf.size) + shuffleMetrics.incRemoteBlocksFetched(1) + } + bytesInFlight -= size + if (isNetworkReqDone) { + reqsInFlight -= 1 + logDebug("Number of requests in flight " + reqsInFlight) + } + + val in = try { + buf.createInputStream() + } catch { + // The exception can only be thrown by a local shuffle block + case e: IOException => + assert(buf.isInstanceOf[FileSegmentManagedBuffer]) + logError("Failed to create input stream from local block", e) + buf.release() + throwFetchFailedException(blockId, address, e) + } + + input = streamWrapper(blockId, in) + // Only copy the stream if it's wrapped by compression or encryption, and the size of the + // block is small (the decompressed block is smaller than maxBytesInFlight) + if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { + val originalInput = input + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + try { + // Decompress the whole block at once to detect any corruption, which could increase + // the memory usage and potentially increase the chance of OOM. + // TODO: manage the memory used here, and spill it into disk in case of OOM.
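A standalone sketch of the retry-once-on-corruption idea described in the comments above: fully reading the wrapped stream forces decompression, so corrupt data surfaces as an `IOException`; the first such failure schedules one refetch, while a repeat failure (or a failure on a local block) escalates to a fetch failure so the previous stage can be retried. Names here are stand-ins for the iterator's internals, not the actual implementation.

```scala
import java.io.{ByteArrayOutputStream, IOException, InputStream}
import scala.collection.mutable

object CorruptionRetrySketch {
  private val corruptedBlocks = mutable.HashSet[String]()

  def readFully(in: InputStream): Array[Byte] = {
    val out = new ByteArrayOutputStream()
    val buf = new Array[Byte](8192)
    var n = in.read(buf)
    while (n != -1) { out.write(buf, 0, n); n = in.read(buf) }
    out.toByteArray
  }

  // Returns the fully materialized bytes, schedules one refetch on the first corruption,
  // and escalates to a fetch failure on a repeat failure (or for a local block).
  def materialize(
      blockId: String,
      wrapped: InputStream,
      isLocalBlock: Boolean,
      refetch: String => Unit,
      fail: (String, Throwable) => Nothing): Option[Array[Byte]] = {
    try {
      Some(readFully(wrapped)) // forces decompression, so corrupt data fails here
    } catch {
      case e: IOException =>
        if (isLocalBlock || corruptedBlocks.contains(blockId)) {
          fail(blockId, e) // second failure or local block: let the stage retry
        } else {
          corruptedBlocks += blockId // remember so the next failure is fatal
          refetch(blockId) // queue exactly one more fetch attempt
          None
        }
    }
  }
}
```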
+ Utils.copyStream(input, out) + out.close() + input = out.toChunkedByteBuffer.toInputStream(dispose = true) + } catch { + case e: IOException => + buf.release() + if (buf.isInstanceOf[FileSegmentManagedBuffer] + || corruptedBlocks.contains(blockId)) { + throwFetchFailedException(blockId, address, e) + } else { + logWarning(s"got an corrupted block $blockId from $address, fetch again", e) + corruptedBlocks += blockId + fetchRequests += FetchRequest(address, Array((blockId, size))) + result = null + } + } finally { + // TODO: release the buf here to free memory earlier + originalInput.close() + in.close() + } + } + + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + } + + // Send fetch requests up to maxBytesInFlight + fetchUpToMaxBytes() } + + currentResult = result.asInstanceOf[SuccessFetchResult] + (currentResult.blockId, new BufferReleasingInputStream(input, this)) } private def fetchUpToMaxBytes(): Unit = { @@ -423,7 +482,7 @@ object ShuffleBlockFetcherIterator { * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. - * @param buf [[ManagedBuffer]] for the content. + * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ private[storage] case class SuccessFetchResult( diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index fad0404bebc36..4c6998d7a8e20 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils * ExternalBlockStore, whether to keep the data in memory in a serialized format, and whether * to replicate the RDD partitions on multiple nodes. * - * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants + * The [[org.apache.spark.storage.StorageLevel]] singleton object contains some static constants * for commonly useful storage levels. To create your own storage level object, use the * factory method of the singleton object (`StorageLevel(...)`). */ diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 798658a15b797..ac60f795915a3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. 
storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() @@ -41,7 +42,7 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { } def deadStorageStatusList: Seq[StorageStatus] = synchronized { - deadExecutorStorageStatus.toSeq + deadExecutorStorageStatus } /** Update storage status list to reflect updated block statuses */ @@ -74,8 +75,10 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener { synchronized { val blockManagerId = blockManagerAdded.blockManagerId val executorId = blockManagerId.executorId - val maxMem = blockManagerAdded.maxMem - val storageStatus = new StorageStatus(blockManagerId, maxMem) + // The onHeap and offHeap memory are always defined for new applications, + // but they can be missing if we are replaying old event logs. + val storageStatus = new StorageStatus(blockManagerId, blockManagerAdded.maxMem, + blockManagerAdded.maxOnHeapMem, blockManagerAdded.maxOffHeapMem) executorIdToStorageStatus(executorId) = storageStatus // Try to remove the dead storage status if same executor register the block manager twice. diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index fb9941bbd9e0f..e9694fdbca2de 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,7 +35,12 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi -class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { +@deprecated("This class may be removed or made private in a future release.", "2.2.0") +class StorageStatus( + val blockManagerId: BlockManagerId, + val maxMemory: Long, + val maxOnHeapMem: Option[Long], + val maxOffHeapMem: Option[Long]) { /** * Internal representation of the blocks stored in this block manager. @@ -46,32 +51,28 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { private val _rddBlocks = new mutable.HashMap[Int, mutable.Map[BlockId, BlockStatus]] private val _nonRddBlocks = new mutable.HashMap[BlockId, BlockStatus] - /** - * Storage information of the blocks that entails memory, disk, and off-heap memory usage. - * - * As with the block maps, we store the storage information separately for RDD blocks and - * non-RDD blocks for the same reason. In particular, RDD storage information is stored - * in a map indexed by the RDD ID to the following 4-tuple: - * - * (memory size, disk size, storage level) - * - * We assume that all the blocks that belong to the same RDD have the same storage level. - * This field is not relevant to non-RDD blocks, however, so the storage information for - * non-RDD blocks contains only the first 3 fields (in the same order). - */ - private val _rddStorageInfo = new mutable.HashMap[Int, (Long, Long, StorageLevel)] - private var _nonRddStorageInfo: (Long, Long) = (0L, 0L) + private case class RddStorageInfo(memoryUsage: Long, diskUsage: Long, level: StorageLevel) + private val _rddStorageInfo = new mutable.HashMap[Int, RddStorageInfo] + + private case class NonRddStorageInfo(var onHeapUsage: Long, var offHeapUsage: Long, + var diskUsage: Long) + private val _nonRddStorageInfo = NonRddStorageInfo(0L, 0L, 0L) /** Create a storage status with an initial set of blocks, leaving the source unmodified. 
*/ - def this(bmid: BlockManagerId, maxMem: Long, initialBlocks: Map[BlockId, BlockStatus]) { - this(bmid, maxMem) + def this( + bmid: BlockManagerId, + maxMemory: Long, + maxOnHeapMem: Option[Long], + maxOffHeapMem: Option[Long], + initialBlocks: Map[BlockId, BlockStatus]) { + this(bmid, maxMemory, maxOnHeapMem, maxOffHeapMem) initialBlocks.foreach { case (bid, bstatus) => addBlock(bid, bstatus) } } /** * Return the blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * contains, get, and size. */ @@ -80,7 +81,7 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the RDD blocks stored in this block manager. * - * Note that this is somewhat expensive, as it involves cloning the underlying maps and then + * @note This is somewhat expensive, as it involves cloning the underlying maps and then * concatenating them together. Much faster alternatives exist for common operations such as * getting the memory, disk, and off-heap memory sizes occupied by this RDD. */ @@ -128,7 +129,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return whether the given block is stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.contains`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.contains`, which is O(blocks) time. */ def containsBlock(blockId: BlockId): Boolean = { blockId match { @@ -141,7 +143,8 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the given block stored in this block manager in O(1) time. - * Note that this is much faster than `this.blocks.get`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.get`, which is O(blocks) time. */ def getBlock(blockId: BlockId): Option[BlockStatus] = { blockId match { @@ -154,43 +157,77 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { /** * Return the number of blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.blocks.size`, which is O(blocks) time. + * + * @note This is much faster than `this.blocks.size`, which is O(blocks) time. */ def numBlocks: Int = _nonRddBlocks.size + numRddBlocks /** * Return the number of RDD blocks stored in this block manager in O(RDDs) time. - * Note that this is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. + * + * @note This is much faster than `this.rddBlocks.size`, which is O(RDD blocks) time. */ def numRddBlocks: Int = _rddBlocks.values.map(_.size).sum /** * Return the number of blocks that belong to the given RDD in O(1) time. - * Note that this is much faster than `this.rddBlocksById(rddId).size`, which is + * + * @note This is much faster than `this.rddBlocksById(rddId).size`, which is * O(blocks in this RDD) time. */ def numRddBlocksById(rddId: Int): Int = _rddBlocks.get(rddId).map(_.size).getOrElse(0) + /** Return the max memory can be used by this block manager. */ + def maxMem: Long = maxMemory + /** Return the memory remaining in this block manager. 
*/ def memRemaining: Long = maxMem - memUsed + /** Return the memory used by caching RDDs */ + def cacheSize: Long = onHeapCacheSize.getOrElse(0L) + offHeapCacheSize.getOrElse(0L) + /** Return the memory used by this block manager. */ - def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + def memUsed: Long = onHeapMemUsed.getOrElse(0L) + offHeapMemUsed.getOrElse(0L) - /** Return the memory used by caching RDDs */ - def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + /** Return the on-heap memory remaining in this block manager. */ + def onHeapMemRemaining: Option[Long] = + for (m <- maxOnHeapMem; o <- onHeapMemUsed) yield m - o + + /** Return the off-heap memory remaining in this block manager. */ + def offHeapMemRemaining: Option[Long] = + for (m <- maxOffHeapMem; o <- offHeapMemUsed) yield m - o + + /** Return the on-heap memory used by this block manager. */ + def onHeapMemUsed: Option[Long] = onHeapCacheSize.map(_ + _nonRddStorageInfo.onHeapUsage) + + /** Return the off-heap memory used by this block manager. */ + def offHeapMemUsed: Option[Long] = offHeapCacheSize.map(_ + _nonRddStorageInfo.offHeapUsage) + + /** Return the memory used by on-heap caching RDDs */ + def onHeapCacheSize: Option[Long] = maxOnHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if !storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } + + /** Return the memory used by off-heap caching RDDs */ + def offHeapCacheSize: Option[Long] = maxOffHeapMem.map { _ => + _rddStorageInfo.collect { + case (_, storageInfo) if storageInfo.level.useOffHeap => storageInfo.memoryUsage + }.sum + } /** Return the disk space used by this block manager. */ - def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo.diskUsage + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ - def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) + def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.memoryUsage).getOrElse(0L) /** Return the disk space used by the given RDD in this block manager in O(1) time. */ - def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._2).getOrElse(0L) + def diskUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_.diskUsage).getOrElse(0L) /** Return the storage level, if any, used by the given RDD in this block manager. */ - def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_._3) + def rddStorageLevel(rddId: Int): Option[StorageLevel] = _rddStorageInfo.get(rddId).map(_.level) /** * Update the relevant storage info, taking into account any existing status for this block. 
@@ -205,10 +242,12 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { val (oldMem, oldDisk) = blockId match { case RDDBlockId(rddId, _) => _rddStorageInfo.get(rddId) - .map { case (mem, disk, _) => (mem, disk) } + .map { case RddStorageInfo(mem, disk, _) => (mem, disk) } .getOrElse((0L, 0L)) - case _ => - _nonRddStorageInfo + case _ if !level.useOffHeap => + (_nonRddStorageInfo.onHeapUsage, _nonRddStorageInfo.diskUsage) + case _ if level.useOffHeap => + (_nonRddStorageInfo.offHeapUsage, _nonRddStorageInfo.diskUsage) } val newMem = math.max(oldMem + changeInMem, 0L) val newDisk = math.max(oldDisk + changeInDisk, 0L) @@ -220,30 +259,40 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { if (newMem + newDisk == 0) { _rddStorageInfo.remove(rddId) } else { - _rddStorageInfo(rddId) = (newMem, newDisk, level) + _rddStorageInfo(rddId) = RddStorageInfo(newMem, newDisk, level) } case _ => - _nonRddStorageInfo = (newMem, newDisk) + if (!level.useOffHeap) { + _nonRddStorageInfo.onHeapUsage = newMem + } else { + _nonRddStorageInfo.offHeapUsage = newMem + } + _nonRddStorageInfo.diskUsage = newDisk } } - } /** Helper methods for storage-related objects. */ private[spark] object StorageUtils extends Logging { - /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up a ByteBuffer if it is direct or memory-mapped. This uses an *unsafe* Sun + * API that will cause errors if one attempts to read from the disposed buffer. However, neither + * the bytes allocated to direct buffers nor file descriptors opened for memory-mapped buffers put + * pressure on the garbage collector. Waiting for garbage collection may lead to the depletion of + * off-heap memory or huge numbers of open files. There's unfortunately no standard API to + * manually dispose of these kinds of buffers. */ def dispose(buffer: ByteBuffer): Unit = { if (buffer != null && buffer.isInstanceOf[MappedByteBuffer]) { - logTrace(s"Unmapping $buffer") - if (buffer.asInstanceOf[DirectBuffer].cleaner() != null) { - buffer.asInstanceOf[DirectBuffer].cleaner().clean() - } + logTrace(s"Disposing of $buffer") + cleanDirectBuffer(buffer.asInstanceOf[DirectBuffer]) + } + } + + private def cleanDirectBuffer(buffer: DirectBuffer) = { + val cleaner = buffer.cleaner() + if (cleaner != null) { + cleaner.clean() } } diff --git a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala index a0f0fdef8e948..a150a8e3636e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala +++ b/core/src/main/scala/org/apache/spark/storage/TopologyMapper.scala @@ -60,7 +60,7 @@ class DefaultTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with L /** * A simple file based topology mapper. This expects topology information provided as a - * [[java.util.Properties]] file. The name of the file is obtained from SparkConf property + * `java.util.Properties` file. The name of the file is obtained from SparkConf property * `spark.storage.replication.topologyFile`. 
To use this topology mapper, set the * `spark.storage.replication.topologyMapper` property to * [[org.apache.spark.storage.FileBasedTopologyMapper]] diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 095d32407f345..90e3af2d0ec74 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -31,7 +31,7 @@ import org.apache.spark.{SparkConf, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} -import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} +import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} import org.apache.spark.unsafe.Platform import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -331,11 +331,20 @@ private[spark] class MemoryStore( var unrollMemoryUsedByThisBlock = 0L // Underlying buffer for unrolling the block val redirectableStream = new RedirectableOutputStream - val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) + val chunkSize = if (initialMemoryThreshold > Int.MaxValue) { + logWarning(s"Initial memory threshold of ${Utils.bytesToString(initialMemoryThreshold)} " + + s"is too large to be set as chunk size. Chunk size has been capped to " + + s"${Utils.bytesToString(Int.MaxValue)}") + Int.MaxValue + } else { + initialMemoryThreshold.toInt + } + val bbos = new ChunkedByteBufferOutputStream(chunkSize, allocator) redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { - val ser = serializerManager.getSerializer(classTag).newInstance() - ser.serializeStream(serializerManager.wrapStream(blockId, redirectableStream)) + val autoPick = !blockId.isInstanceOf[StreamBlockId] + val ser = serializerManager.getSerializer(classTag, autoPick).newInstance() + ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling @@ -693,7 +702,7 @@ private[storage] class PartiallyUnrolledIterator[T]( } override def next(): T = { - if (unrolled == null) { + if (unrolled == null || !unrolled.hasNext) { rest.next() } else { unrolled.next() diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 35c3c8d00f99b..edf328b5ae538 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -27,10 +27,10 @@ import scala.xml.Node import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet -import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector} +import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue @@ -45,6 +45,9 @@ import org.apache.spark.util.Utils */ private[spark] object JettyUtils extends Logging { + val SPARK_CONNECTOR_NAME = "Spark" + val REDIRECT_CONNECTOR_NAME = "HttpsRedirect" + 
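The two connector names just defined drive the routing rework further down: every handler is tagged with a virtual host of the form `"@<connector name>"`, so content handlers are served only by the main connector while the HTTPS-redirect handler is reachable only through the redirect connector. A minimal sketch of that Jetty convention (a standalone illustration, not the actual `JettyUtils` code):

```scala
import org.eclipse.jetty.server.{Server, ServerConnector}
import org.eclipse.jetty.server.handler.ContextHandler

// Sketch of Jetty's named-connector routing: a handler whose virtual hosts include
// "@Spark" is served only by the connector named "Spark".
def bindToConnector(server: Server, handler: ContextHandler, connectorName: String): Unit = {
  val connector = new ServerConnector(server)
  connector.setName(connectorName) // e.g. "Spark" or "HttpsRedirect"
  server.addConnector(connector)
  handler.setVirtualHosts(Array("@" + connectorName)) // "@<name>" matches by connector name
  server.setHandler(handler)
}
```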
// Base type for a function that returns something based on an HTTP request. Allows for // implicit conversion from many types of functions to jetty Handlers. type Responder[T] = HttpServletRequest => T @@ -87,9 +90,9 @@ private[spark] object JettyUtils extends Logging { response.setHeader("X-Frame-Options", xFrameOptionsValue) response.getWriter.print(servletParams.extractFn(result)) } else { - response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setStatus(HttpServletResponse.SC_FORBIDDEN) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + response.sendError(HttpServletResponse.SC_FORBIDDEN, "User is not authorized to access this page.") } } catch { @@ -274,95 +277,127 @@ private[spark] object JettyUtils extends Logging { conf: SparkConf, serverName: String = ""): ServerInfo = { - val collection = new ContextHandlerCollection addFilters(handlers, conf) - val gzipHandlers = handlers.map { h => - val gzipHandler = new GzipHandler - gzipHandler.setHandler(h) - gzipHandler + // Start the server first, with no connectors. + val pool = new QueuedThreadPool + if (serverName.nonEmpty) { + pool.setName(serverName) } + pool.setDaemon(true) - // Bind to the given port, or throw a java.net.BindException if the port is occupied - def connect(currentPort: Int): (Server, Int) = { - val pool = new QueuedThreadPool - if (serverName.nonEmpty) { - pool.setName(serverName) - } - pool.setDaemon(true) - - val server = new Server(pool) - val connectors = new ArrayBuffer[ServerConnector] - // Create a connector on port currentPort to listen for HTTP requests - val httpConnector = new ServerConnector( - server, - null, - // Call this full constructor to set this, which forces daemon threads: - new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true), - null, - -1, - -1, - new HttpConnectionFactory()) - httpConnector.setPort(currentPort) - connectors += httpConnector - - sslOptions.createJettySslContextFactory().foreach { factory => - // If the new port wraps around, do not try a privileged port. - val securePort = - if (currentPort != 0) { - (currentPort + 400 - 1024) % (65536 - 1024) + 1024 - } else { - 0 - } - val scheme = "https" - // Create a connector on port securePort to listen for HTTPS requests - val connector = new ServerConnector(server, factory) - connector.setPort(securePort) + val server = new Server(pool) - connectors += connector + val errorHandler = new ErrorHandler() + errorHandler.setShowStacks(true) + errorHandler.setServer(server) + server.addBean(errorHandler) - // redirect the HTTP requests to HTTPS port - collection.addHandler(createRedirectHttpsHandler(securePort, scheme)) - } + val collection = new ContextHandlerCollection + server.setHandler(collection) + + // Executor used to create daemon threads for the Jetty connectors. + val serverExecutor = new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true) + + try { + server.start() - gzipHandlers.foreach(collection.addHandler) // As each acceptor and each selector will use one thread, the number of threads should at // least be the number of acceptors and selectors plus 1. 
(See SPARK-13776) var minThreads = 1 - connectors.foreach { connector => + + def newConnector( + connectionFactories: Array[ConnectionFactory], + port: Int): (ServerConnector, Int) = { + val connector = new ServerConnector( + server, + null, + serverExecutor, + null, + -1, + -1, + connectionFactories: _*) + connector.setPort(port) + connector.start() + // Currently we only use "SelectChannelConnector" // Limit the max acceptor number to 8 so that we don't waste a lot of threads connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8)) connector.setHost(hostName) // The number of selectors always equals to the number of acceptors minThreads += connector.getAcceptors * 2 + + (connector, connector.getLocalPort()) } - server.setConnectors(connectors.toArray) - pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) - val errorHandler = new ErrorHandler() - errorHandler.setShowStacks(true) - errorHandler.setServer(server) - server.addBean(errorHandler) - server.setHandler(collection) - try { - server.start() - (server, httpConnector.getLocalPort) - } catch { - case e: Exception => - server.stop() - pool.stop() - throw e + // If SSL is configured, create the secure connector first. + val securePort = sslOptions.createJettySslContextFactory().map { factory => + val securePort = sslOptions.port.getOrElse(if (port > 0) Utils.userPort(port, 400) else 0) + val secureServerName = if (serverName.nonEmpty) s"$serverName (HTTPS)" else serverName + val connectionFactories = AbstractConnectionFactory.getFactories(factory, + new HttpConnectionFactory()) + + def sslConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(connectionFactories, currentPort) + } + + val (connector, boundPort) = Utils.startServiceOnPort[ServerConnector](securePort, + sslConnect, conf, secureServerName) + connector.setName(SPARK_CONNECTOR_NAME) + server.addConnector(connector) + boundPort + } + + // Bind the HTTP port. + def httpConnect(currentPort: Int): (ServerConnector, Int) = { + newConnector(Array(new HttpConnectionFactory()), currentPort) + } + + val (httpConnector, httpPort) = Utils.startServiceOnPort[ServerConnector](port, httpConnect, + conf, serverName) + + // If SSL is configured, then configure redirection in the HTTP connector. + securePort match { + case Some(p) => + httpConnector.setName(REDIRECT_CONNECTOR_NAME) + val redirector = createRedirectHttpsHandler(p, "https") + collection.addHandler(redirector) + redirector.start() + + case None => + httpConnector.setName(SPARK_CONNECTOR_NAME) } - } - val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName) - ServerInfo(server, boundPort, collection) + server.addConnector(httpConnector) + + // Add all the known handlers now that connectors are configured. 
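One detail worth calling out in the hunk above: the HTTPS port is now derived with `Utils.userPort(port, 400)` rather than the inline modular arithmetic the old code carried. A small sketch, assuming `userPort` keeps the same formula as the expression it replaces; the handler registration loop resumes in the diff just below.

```scala
// Assumed to match the inline expression removed in this hunk:
// (port + 400 - 1024) % (65536 - 1024) + 1024, i.e. offset by 400 while
// staying out of the privileged port range.
def userPort(base: Int, offset: Int): Int =
  (base + offset - 1024) % (65536 - 1024) + 1024

userPort(4040, 400)   // 4440: default UI port shifted by the HTTPS offset
userPort(65500, 400)  // 1388: wraps around instead of hitting a privileged port
```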
+ handlers.foreach { h => + h.setVirtualHosts(toVirtualHosts(SPARK_CONNECTOR_NAME)) + val gzipHandler = new GzipHandler() + gzipHandler.setHandler(h) + collection.addHandler(gzipHandler) + gzipHandler.start() + } + + pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads)) + ServerInfo(server, httpPort, securePort, collection) + } catch { + case e: Exception => + server.stop() + if (serverExecutor.isStarted()) { + serverExecutor.stop() + } + if (pool.isStarted()) { + pool.stop() + } + throw e + } } private def createRedirectHttpsHandler(securePort: Int, scheme: String): ContextHandler = { val redirectHandler: ContextHandler = new ContextHandler redirectHandler.setContextPath("/") + redirectHandler.setVirtualHosts(toVirtualHosts(REDIRECT_CONNECTOR_NAME)) redirectHandler.setHandler(new AbstractHandler { override def handle( target: String, @@ -375,8 +410,7 @@ private[spark] object JettyUtils extends Logging { val httpsURI = createRedirectURI(scheme, baseRequest.getServerName, securePort, baseRequest.getRequestURI, baseRequest.getQueryString) response.setContentLength(0) - response.encodeRedirectURL(httpsURI) - response.sendRedirect(httpsURI) + response.sendRedirect(response.encodeRedirectURL(httpsURI)) baseRequest.setHandled(true) } }) @@ -437,12 +471,30 @@ private[spark] object JettyUtils extends Logging { new URI(scheme, authority, path, query, null).toString } + def toVirtualHosts(connectors: String*): Array[String] = connectors.map("@" + _).toArray + } private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) { + securePort: Option[Int], + private val rootHandler: ContextHandlerCollection) { + + def addHandler(handler: ContextHandler): Unit = { + handler.setVirtualHosts(JettyUtils.toVirtualHosts(JettyUtils.SPARK_CONNECTOR_NAME)) + rootHandler.addHandler(handler) + if (!handler.isStarted()) { + handler.start() + } + } + + def removeHandler(handler: ContextHandler): Unit = { + rootHandler.removeHandler(handler) + if (handler.isStarted) { + handler.stop() + } + } def stop(): Unit = { server.stop() diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 2a7c16b04bf7f..79974df2603fd 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -175,13 +175,14 @@ private[ui] trait PagedTable[T] { val hiddenFormFields = { if (goButtonFormPath.contains('?')) { - val querystring = goButtonFormPath.split("\\?", 2)(1) + val queryString = goButtonFormPath.split("\\?", 2)(1) + val search = queryString.split("#")(0) Splitter .on('&') .trimResults() .omitEmptyStrings() .withKeyValueSeparator("=") - .split(querystring) + .split(search) .asScala .filterKeys(_ != pageSizeFormField) .filterKeys(_ != prevPageSizeFormField) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index f631a047a707d..f271c56021e95 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -60,6 +60,10 @@ private[spark] class SparkUI private ( var appId: String = _ + var appSparkVersion = org.apache.spark.SPARK_VERSION + + private var streamingJobProgressListener: Option[SparkListener] = None + /** Initialize all components of the server. 
*/ def initialize() { val jobsTab = new JobsTab(this) @@ -82,7 +86,7 @@ private[spark] class SparkUI private ( initialize() def getSparkUser: String = { - environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + environmentListener.systemProperties.toMap.getOrElse("user.name", "") } def getAppName: String = appName @@ -94,16 +98,9 @@ private[spark] class SparkUI private ( /** Stop the server behind this web interface. Only valid after bind(). */ override def stop() { super.stop() - logInfo("Stopped Spark web UI at %s".format(appUIAddress)) + logInfo(s"Stopped Spark web UI at $webUrl") } - /** - * Return the application UI host:port. This does not include the scheme (http://). - */ - private[spark] def appUIHostPort = publicHostName + ":" + boundPort - - private[spark] def appUIAddress = s"http://$appUIHostPort" - def getSparkUI(appId: String): Option[SparkUI] = { if (appId == this.appId) Some(this) else None } @@ -122,8 +119,9 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", - completed = false + sparkUser = getSparkUser, + completed = false, + appSparkVersion = appSparkVersion )) )) } @@ -131,13 +129,20 @@ private[spark] class SparkUI private ( def getApplicationInfo(appId: String): Option[ApplicationInfo] = { getApplicationInfoList.find(_.id == appId) } + + def getStreamingJobProgressListener: Option[SparkListener] = streamingJobProgressListener + + def setStreamingJobProgressListener(sparkListener: SparkListener): Unit = { + streamingJobProgressListener = Option(sparkListener) + } } private[spark] abstract class SparkUITab(parent: SparkUI, prefix: String) extends WebUITab(parent, prefix) { - def appName: String = parent.getAppName + def appName: String = parent.appName + def appSparkVersion: String = parent.appSparkVersion } private[spark] object SparkUI { diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 3cc5353f475f4..766cc65084f07 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -91,6 +91,9 @@ private[spark] object ToolTips { val TASK_TIME = "Shaded red when garbage collection (GC) time is over 10% of task time" + val BLACKLISTED = + "Shows if this executor has been blacklisted by the scheduler due to task failures." + val APPLICATION_EXECUTOR_LIMIT = """Maximum number of executors that this application will use. This limit is finite only when dynamic allocation is enabled. The number of granted executors may exceed the limit diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index c0d1a2220f62a..2610f673d27f6 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import scala.xml._ import scala.xml.transform.{RewriteRule, RuleTransformer} +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.internal.Logging import org.apache.spark.ui.scope.RDDOperationGraph @@ -34,9 +36,12 @@ private[spark] object UIUtils extends Logging { val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable" + private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r + // SimpleDateFormat is not thread-safe. 
Don't expose it to avoid improper use. private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + override def initialValue(): SimpleDateFormat = + new SimpleDateFormat("yyyy/MM/dd HH:mm:ss", Locale.US) } def formatDate(date: Date): String = dateFormat.get.format(date) @@ -170,6 +175,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { @@ -226,7 +232,7 @@ private[spark] object UIUtils extends Logging {
    @@ -420,8 +429,8 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like ++ ++ } -
    ; +
    UIUtils.headerSparkPage("Executors", content, parent, useDataTables = true) } } private[spark] object ExecutorsPage { + private val ON_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for on heap " + + "storage of data like RDD partitions cached in memory." + private val OFF_HEAP_MEMORY_TOOLTIP = "Memory used / total available memory for off heap " + + "storage of data like RDD partitions cached in memory." + /** Represent an executor's info as a map given a storage status index */ def getExecInfo( listener: ExecutorsListener, @@ -80,6 +114,16 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -101,8 +145,10 @@ private[spark] object ExecutorsPage { taskSummary.inputBytes, taskSummary.shuffleRead, taskSummary.shuffleWrite, + taskSummary.isBlacklisted, maxMem, - taskSummary.executorLogs + taskSummary.executorLogs, + memoryMetrics ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 678571fd4f5ac..aabf6e0c63c02 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -53,7 +53,8 @@ private[ui] case class ExecutorTaskSummary( var shuffleRead: Long = 0L, var shuffleWrite: Long = 0L, var executorLogs: Map[String, String] = Map.empty, - var isAlive: Boolean = true + var isAlive: Boolean = true, + var isBlacklisted: Boolean = false ) /** @@ -61,9 +62,10 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { - var executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() + val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() var executorEvents = new ListBuffer[SparkListenerEvent]() private val maxTimelineExecutors = conf.getInt("spark.ui.timeline.executors.maximum", 1000) @@ -73,7 +75,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar def deadStorageStatusList: Seq[StorageStatus] = storageStatusListener.deadStorageStatusList - override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = synchronized { + override def onExecutorAdded( + executorAdded: SparkListenerExecutorAdded): Unit = synchronized { val eid = executorAdded.executorId val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) taskSummary.executorLogs = executorAdded.executorInfo.logUrlMap @@ -100,7 +103,8 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar executorToTaskSummary.get(executorRemoved.executorId).foreach(e => e.isAlive = false) } - override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { + override def onApplicationStart( + applicationStart: SparkListenerApplicationStart): Unit = { 
applicationStart.driverLogs.foreach { logs => val storageStatus = activeStorageStatusList.find { s => s.blockManagerId.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER || @@ -114,13 +118,15 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar } } - override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + override def onTaskStart( + taskStart: SparkListenerTaskStart): Unit = synchronized { val eid = taskStart.taskInfo.executorId val taskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) taskSummary.tasksActive += 1 } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + override def onTaskEnd( + taskEnd: SparkListenerTaskEnd): Unit = synchronized { val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId @@ -132,7 +138,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // could have failed half-way through. The correct fix would be to keep track of the // metrics added by each attempt, but this is much more complicated. return - case e: ExceptionFailure => + case _: ExceptionFailure => taskSummary.tasksFailed += 1 case _ => taskSummary.tasksComplete += 1 @@ -157,4 +163,46 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar } } + private def updateExecutorBlacklist( + eid: String, + isBlacklisted: Boolean): Unit = { + val execTaskSummary = executorToTaskSummary.getOrElseUpdate(eid, ExecutorTaskSummary(eid)) + execTaskSummary.isBlacklisted = isBlacklisted + } + + override def onExecutorBlacklisted( + executorBlacklisted: SparkListenerExecutorBlacklisted) + : Unit = synchronized { + updateExecutorBlacklist(executorBlacklisted.executorId, true) + } + + override def onExecutorUnblacklisted( + executorUnblacklisted: SparkListenerExecutorUnblacklisted) + : Unit = synchronized { + updateExecutorBlacklist(executorUnblacklisted.executorId, false) + } + + override def onNodeBlacklisted( + nodeBlacklisted: SparkListenerNodeBlacklisted) + : Unit = synchronized { + // Implicitly blacklist every executor associated with this node, and show this in the UI. + activeStorageStatusList.foreach { status => + if (status.blockManagerId.host == nodeBlacklisted.hostId) { + updateExecutorBlacklist(status.blockManagerId.executorId, true) + } + } + } + + override def onNodeUnblacklisted( + nodeUnblacklisted: SparkListenerNodeUnblacklisted) + : Unit = synchronized { + // Implicitly unblacklist every executor associated with this node, regardless of how + // they may have been blacklisted initially (either explicitly through executor blacklisting + // or implicitly through node blacklisting). Show this in the UI. 
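The blacklist callbacks in this hunk keep a per-executor flag in `ExecutorTaskSummary` so the executors page can render the new Blacklisted column; the node-level unblacklist loop continues just below. A trimmed-down, hypothetical listener doing only that bookkeeping:

```scala
import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorBlacklisted, SparkListenerExecutorUnblacklisted}

// Hypothetical minimal listener: record only whether each executor is currently
// blacklisted, the same flag ExecutorsListener maintains for the UI.
class BlacklistTrackingListener extends SparkListener {
  private val blacklisted = mutable.Map.empty[String, Boolean].withDefaultValue(false)

  override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit =
    synchronized { blacklisted(event.executorId) = true }

  override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit =
    synchronized { blacklisted(event.executorId) = false }

  def isBlacklisted(executorId: String): Boolean = synchronized { blacklisted(executorId) }
}
```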
+ activeStorageStatusList.foreach { status => + if (status.blockManagerId.host == nodeUnblacklisted.hostId) { + updateExecutorBlacklist(status.blockManagerId.executorId, false) + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index 173fc3cf31ce8..a0fd29c22ddca 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { jobTag: String, jobs: Seq[JobUIData], killEnabled: Boolean): Seq[Node] = { - val allParameters = request.getParameterMap.asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag)) .map(para => para._1 + "=" + para._2(0)) val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id" - val parameterJobPage = request.getParameter(jobTag + ".page") - val parameterJobSortColumn = request.getParameter(jobTag + ".sort") - val parameterJobSortDesc = request.getParameter(jobTag + ".desc") - val parameterJobPageSize = request.getParameter(jobTag + ".pageSize") - val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page")) + val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort")) + val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc")) + val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize")) + val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize")) val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1) val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn => @@ -289,8 +291,8 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val startTime = listener.startTime val endTime = listener.endTime val activeJobs = listener.activeJobs.values.toSeq - val completedJobs = listener.completedJobs.reverse.toSeq - val failedJobs = listener.failedJobs.reverse.toSeq + val completedJobs = listener.completedJobs.reverse + val failedJobs = listener.failedJobs.reverse val activeJobsTable = jobsTable(request, "active", "activeJob", activeJobs, killEnabled = parent.killEnabled) @@ -500,7 +502,8 @@ private[ui] class JobPagedTable( override def tableId: String = jobTag + "-table" override def tableCssClass: String = - "table table-bordered table-condensed table-striped table-head-clickable" + "table table-bordered table-condensed table-striped " + + "table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = jobTag + ".pageSize" @@ -629,8 +632,8 @@ private[ui] class JobPagedTable( {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, - failed = job.numFailedTasks, skipped = job.numSkippedTasks, killed = job.numKilledTasks, - total = job.numTasks - job.numSkippedTasks)} + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + reasonToNumKilled = job.reasonToNumKilled, total = job.numTasks - job.numSkippedTasks)} } diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index fe6ca1099e6b0..2b0816e35747d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -34,9 +34,9 @@ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { listener.synchronized { val activeStages = listener.activeStages.values.toSeq val pendingStages = listener.pendingStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq + val completedStages = listener.completedStages.reverse val numCompletedStages = listener.numCompletedStages - val failedStages = listener.failedStages.reverse.toSeq + val failedStages = listener.failedStages.reverse val numFailedStages = listener.numFailedStages val subPath = "stages" diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 9fb3f35fd9685..382a6f979f2e6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -85,6 +85,11 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage Shuffle Spill (Memory) Shuffle Spill (Disk) }} + + + Blacklisted + + {createExecutorTable()} @@ -128,9 +133,9 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} {UIUtils.formatDuration(v.taskTime)} - {v.failedTasks + v.succeededTasks + v.killedTasks} + {v.failedTasks + v.succeededTasks + v.reasonToNumKilled.values.sum} {v.failedTasks} - {v.killedTasks} + {v.reasonToNumKilled.values.sum} {v.succeededTasks} {if (stageData.hasInput) { @@ -160,6 +165,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage {Utils.bytesToString(v.diskBytesSpilled)} }} + {v.isBlacklisted} } case None => diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 0ff9e5e9411ca..9fb011a049b7e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import java.util.Date +import java.util.{Date, Locale} import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, ListBuffer} @@ -77,7 +77,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { | 'content': '
    retainedStages) { - val toRemove = (stages.size - retainedStages) + val toRemove = calculateNumberToRemove(stages.size, retainedStages) stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) stageIdToInfo.remove(s.stageId) @@ -154,7 +155,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { /** If jobs is too large, remove and garbage collect old jobs */ private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { if (jobs.size > retainedJobs) { - val toRemove = (jobs.size - retainedJobs) + val toRemove = calculateNumberToRemove(jobs.size, retainedJobs) jobs.take(toRemove).foreach { job => // Remove the job's UI data, if it exists jobIdToData.remove(job.jobId).foreach { removedJob => @@ -226,7 +227,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { trimJobsIfNecessary(completedJobs) jobData.status = JobExecutionStatus.SUCCEEDED numCompletedJobs += 1 - case JobFailed(exception) => + case JobFailed(_) => failedJobs += jobData trimJobsIfNecessary(failedJobs) jobData.status = JobExecutionStatus.FAILED @@ -284,7 +285,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { ) { jobData.numActiveStages -= 1 if (stage.failureReason.isEmpty) { - if (!stage.submissionTime.isEmpty) { + if (stage.submissionTime.isDefined) { jobData.completedStageIndices.add(stage.stageId) } } else { @@ -371,8 +372,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => execSummary.succeededTasks += 1 - case TaskKilled => - execSummary.killedTasks += 1 + case kill: TaskKilled => + execSummary.reasonToNumKilled = execSummary.reasonToNumKilled.updated( + kill.reason, execSummary.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => execSummary.failedTasks += 1 } @@ -385,9 +387,10 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.completedIndices.add(info.index) stageData.numCompleteTasks += 1 None - case TaskKilled => - stageData.numKilledTasks += 1 - Some(TaskKilled.toErrorString) + case kill: TaskKilled => + stageData.reasonToNumKilled = stageData.reasonToNumKilled.updated( + kill.reason, stageData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) + Some(kill.toErrorString) case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates stageData.numFailedTasks += 1 Some(e.toErrorString) @@ -409,7 +412,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // If Tasks is too large, remove and garbage collect old tasks if (stageData.taskData.size > retainedTasks) { - stageData.taskData = stageData.taskData.drop(stageData.taskData.size - retainedTasks) + stageData.taskData = stageData.taskData.drop( + calculateNumberToRemove(stageData.taskData.size, retainedTasks)) } for ( @@ -421,8 +425,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskEnd.reason match { case Success => jobData.numCompletedTasks += 1 - case TaskKilled => - jobData.numKilledTasks += 1 + case kill: TaskKilled => + jobData.reasonToNumKilled = jobData.reasonToNumKilled.updated( + kill.reason, jobData.reasonToNumKilled.getOrElse(kill.reason, 0) + 1) case _ => jobData.numFailedTasks += 1 } @@ -430,6 +435,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { } } + /** + * Remove at least (maxRetained / 10) items to reduce friction. 
+ */ + private def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int = { + math.max(retainedSize / 10, dataSize - retainedSize) + } + /** * Upon receiving new metrics for a task, updates the per-stage and per-executor-per-stage * aggregate metrics by calculating deltas between the currently recorded metrics and the new diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 620c54c2dc0a5..cc173381879a6 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all jobs in the given SparkContext. */ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val jobId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) jobId.foreach { id => if (jobProgresslistener.activeJobs.contains(id)) { sc.foreach(_.cancelJob(id)) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 8ee70d27cc09f..b164f32b62e97 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { - val poolName = Option(request.getParameter("poolname")).map { poolname => + // stripXSS is called first to remove suspicious characters used in XSS attacks + val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname => UIUtils.decodeURLParameter(poolname) }.getOrElse { throw new IllegalArgumentException(s"Missing poolname parameter") diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8c7cefe200739..6b3dadc333316 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -70,8 +70,6 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // if we find that it's okay. 
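The `calculateNumberToRemove` helper above changes how the listener trims its retained jobs, stages, and tasks: instead of dropping exactly one entry every time the limit is exceeded, it drops at least a tenth of the limit in one pass, so the cleanup cost is paid far less often. A short worked sketch using the same formula:

```scala
// Same formula as the helper introduced in this hunk.
def calculateNumberToRemove(dataSize: Int, retainedSize: Int): Int =
  math.max(retainedSize / 10, dataSize - retainedSize)

calculateNumberToRemove(1001, 1000)  // 100: frees head-room for ~100 new entries at once
calculateNumberToRemove(5000, 1000)  // 4000: still trims all the way back to the limit
```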
private val MAX_TIMELINE_TASKS = parent.conf.getInt("spark.ui.timeline.tasks.maximum", 1000) - private val displayPeakExecutionMemory = parent.conf.getBoolean("spark.sql.unsafe.enabled", true) - private def getLocalitySummaryString(stageData: StageUIData): String = { val localities = stageData.taskData.values.map(_.taskInfo.taskLocality) val localityCounts = localities.groupBy(identity).mapValues(_.size) @@ -89,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def render(request: HttpServletRequest): Seq[Node] = { progressListener.synchronized { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterAttempt = request.getParameter("attempt") + val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt")) require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") - val parameterTaskPage = request.getParameter("task.page") - val parameterTaskSortColumn = request.getParameter("task.sort") - val parameterTaskSortDesc = request.getParameter("task.desc") - val parameterTaskPageSize = request.getParameter("task.pageSize") - val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") + val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page")) + val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort")) + val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc")) + val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize")) + val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize")) val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn => @@ -144,7 +143,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val allAccumulables = progressListener.stageIdToData((stageId, stageAttemptId)).accumulables val externalAccumulables = allAccumulables.values.filter { acc => !acc.internal } - val hasAccumulators = externalAccumulables.size > 0 + val hasAccumulators = externalAccumulables.nonEmpty val summary =
    @@ -252,15 +251,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Getting Result Time - {if (displayPeakExecutionMemory) { -
  • - - - Peak Execution Memory - -
  • - }} +
  • + + + Peak Execution Memory + +
  • @@ -343,7 +340,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = - if (validTasks.size == 0) { + if (validTasks.isEmpty) { None } else { @@ -532,13 +529,9 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { {serializationQuantiles} , {gettingResultQuantiles}, - if (displayPeakExecutionMemory) { - - {peakExecutionMemoryQuantiles} - - } else { - Nil - }, + + {peakExecutionMemoryQuantiles} + , if (stageData.hasInput) {inputQuantiles} else Nil, if (stageData.hasOutput) {outputQuantiles} else Nil, if (stageData.hasShuffleRead) { @@ -794,8 +787,8 @@ private[ui] object StagePage { info: TaskInfo, metrics: TaskMetricsUIData, currentTime: Long): Long = { if (info.finished) { val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) + val executorOverhead = metrics.executorDeserializeTime + + metrics.resultSerializationTime math.max( 0, totalExecutionTime - metrics.executorRunTime - executorOverhead - @@ -880,7 +873,7 @@ private[ui] class TaskDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) - private var _slicedTaskIds: Set[Long] = null + private var _slicedTaskIds: Set[Long] = _ override def dataSize: Int = data.size @@ -895,10 +888,8 @@ private[ui] class TaskDataSource( private def taskRow(taskData: TaskUIData): TaskTableRowData = { val info = taskData.taskInfo val metrics = taskData.metrics - val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) - else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) - else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val duration = taskData.taskDuration.getOrElse(1L) + val formatDuration = taskData.taskDuration.map(d => UIUtils.formatDuration(d)).getOrElse("") val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) @@ -1166,9 +1157,6 @@ private[ui] class TaskPagedTable( desc: Boolean, executorsListener: ExecutorsListener) extends PagedTable[TaskTableRowData] { - // We only track peak memory used for unsafe operators - private val displayPeakExecutionMemory = conf.getBoolean("spark.sql.unsafe.enabled", true) - override def tableId: String = "task-table" override def tableCssClass: String = @@ -1217,14 +1205,8 @@ private[ui] class TaskPagedTable( ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), ("GC Time", ""), ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ - { - if (displayPeakExecutionMemory) { - Seq(("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) - } else { - Nil - } - } ++ + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME), + ("Peak Execution Memory", TaskDetailsClassNames.PEAK_EXECUTION_MEMORY)) ++ {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ @@ -1316,11 
+1298,9 @@ private[ui] class TaskPagedTable( {UIUtils.formatDuration(task.gettingResultTime)} - {if (displayPeakExecutionMemory) { - - {Utils.bytesToString(task.peakExecutionMemoryUsed)} - - }} + + {Utils.bytesToString(task.peakExecutionMemoryUsed)} + {if (task.accumulators.nonEmpty) { {Unparsed(task.accumulators.get)} }} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index c9d0431e2d2f7..a28daf7f90451 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -42,15 +42,17 @@ private[ui] class StageTableBase( isFairScheduler: Boolean, killEnabled: Boolean, isFailedStage: Boolean) { - val allParameters = request.getParameterMap().asScala.toMap + // stripXSS is called to remove suspicious characters used in XSS attacks + val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS)) val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag)) .map(para => para._1 + "=" + para._2(0)) - val parameterStagePage = request.getParameter(stageTag + ".page") - val parameterStageSortColumn = request.getParameter(stageTag + ".sort") - val parameterStageSortDesc = request.getParameter(stageTag + ".desc") - val parameterStagePageSize = request.getParameter(stageTag + ".pageSize") - val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize") + val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page")) + val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort")) + val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc")) + val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize")) + val parameterStagePrevPageSize = + UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize")) val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1) val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn => @@ -149,7 +151,8 @@ private[ui] class StagePagedTable( override def tableId: String = stageTag + "-table" override def tableCssClass: String = - "table table-bordered table-condensed table-striped table-head-clickable" + "table table-bordered table-condensed table-striped " + + "table-head-clickable table-cell-width-limited" override def pageSizeFormField: String = stageTag + ".pageSize" @@ -299,7 +302,7 @@ private[ui] class StagePagedTable( {UIUtils.makeProgressBar(started = stageData.numActiveTasks, completed = stageData.completedIndices.size, failed = stageData.numFailedTasks, - skipped = 0, killed = stageData.numKilledTasks, total = info.numTasks)} + skipped = 0, reasonToNumKilled = stageData.reasonToNumKilled, total = info.numTasks)} {data.inputReadWithUnit} {data.outputWriteWithUnit} @@ -411,7 +414,7 @@ private[ui] class StageDataSource( // so that we can avoid creating duplicate contents during sorting the data private val data = stages.map(stageRow).sorted(ordering(sortColumn, desc)) - private var _slicedStageIds: Set[Int] = null + private var _slicedStageIds: Set[Int] = _ override def dataSize: Int = data.size @@ -511,4 +514,3 @@ private[ui] class StageDataSource( } } } - diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala index c1f25114371f1..799d769626395 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs import javax.servlet.http.HttpServletRequest import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} +import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils} /** Web UI showing progress status of all stages in the given SparkContext. */ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") { @@ -39,10 +39,11 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages" def handleKillRequest(request: HttpServletRequest): Unit = { if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) { - val stageId = Option(request.getParameter("id")).map(_.toInt) + // stripXSS is called first to remove suspicious characters used in XSS attacks + val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt) stageId.foreach { id => if (progressListener.activeStages.contains(id)) { - sc.foreach(_.cancelStage(id)) + sc.foreach(_.cancelStage(id, "killed via the Web UI")) // Do a quick pause here to give Spark time to kill the stage so it shows up as // killed after the refresh. Note that this will block the serving thread so the // time should be limited in duration. diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f4a04609c4c69..ac1a74ad8029d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.{HashMap, LinkedHashMap} import org.apache.spark.JobExecutionStatus -import org.apache.spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.executor._ import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.AccumulatorContext import org.apache.spark.util.collection.OpenHashSet @@ -32,7 +32,7 @@ private[spark] object UIData { var taskTime : Long = 0 var failedTasks : Int = 0 var succeededTasks : Int = 0 - var killedTasks : Int = 0 + var reasonToNumKilled : Map[String, Int] = Map.empty var inputBytes : Long = 0 var inputRecords : Long = 0 var outputBytes : Long = 0 @@ -43,6 +43,7 @@ private[spark] object UIData { var shuffleWriteRecords : Long = 0 var memoryBytesSpilled : Long = 0 var diskBytesSpilled : Long = 0 + var isBlacklisted : Int = 0 } class JobUIData( @@ -63,7 +64,7 @@ private[spark] object UIData { var numCompletedTasks: Int = 0, var numSkippedTasks: Int = 0, var numFailedTasks: Int = 0, - var numKilledTasks: Int = 0, + var reasonToNumKilled: Map[String, Int] = Map.empty, /* Stages */ var numActiveStages: Int = 0, // This needs to be a set instead of a simple count to prevent double-counting of rerun stages: @@ -77,7 +78,7 @@ private[spark] object UIData { var numCompleteTasks: Int = _ var completedIndices = new OpenHashSet[Int]() var numFailedTasks: Int = _ - var numKilledTasks: Int = _ + var reasonToNumKilled: Map[String, Int] = Map.empty var executorRunTime: Long = _ var executorCpuTime: Long = _ @@ -92,6 +93,7 @@ private[spark] object UIData { var shuffleWriteRecords: Long = _ var memoryBytesSpilled: Long = _ var diskBytesSpilled: Long = _ + var isBlacklisted: Int = _ var schedulingPool: String = "" var description: Option[String] = None @@ -127,6 +129,14 @@ private[spark] object UIData { def updateTaskMetrics(metrics: 
Option[TaskMetrics]): Unit = { _metrics = TaskUIData.toTaskMetricsUIData(metrics) } + + def taskDuration: Option[Long] = { + if (taskInfo.status == "RUNNING") { + Some(_taskInfo.timeRunning(System.currentTimeMillis)) + } else { + _metrics.map(_.executorRunTime) + } + } } object TaskUIData { @@ -147,9 +157,8 @@ private[spark] object UIData { memoryBytesSpilled = m.memoryBytesSpilled, diskBytesSpilled = m.diskBytesSpilled, peakExecutionMemory = m.peakExecutionMemory, - inputMetrics = InputMetricsUIData(m.inputMetrics.bytesRead, m.inputMetrics.recordsRead), - outputMetrics = - OutputMetricsUIData(m.outputMetrics.bytesWritten, m.outputMetrics.recordsWritten), + inputMetrics = InputMetricsUIData(m.inputMetrics), + outputMetrics = OutputMetricsUIData(m.outputMetrics), shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics), shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics)) } @@ -171,11 +180,12 @@ private[spark] object UIData { speculative = taskInfo.speculative ) newTaskInfo.gettingResultTime = taskInfo.gettingResultTime - newTaskInfo.accumulables ++= taskInfo.accumulables.filter { + newTaskInfo.setAccumulables(taskInfo.accumulables.filter { accum => !accum.internal && accum.metadata != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER) - } + }) newTaskInfo.finishTime = taskInfo.finishTime newTaskInfo.failed = taskInfo.failed + newTaskInfo.killed = taskInfo.killed newTaskInfo } } @@ -197,8 +207,32 @@ private[spark] object UIData { shuffleWriteMetrics: ShuffleWriteMetricsUIData) case class InputMetricsUIData(bytesRead: Long, recordsRead: Long) + object InputMetricsUIData { + def apply(metrics: InputMetrics): InputMetricsUIData = { + if (metrics.bytesRead == 0 && metrics.recordsRead == 0) { + EMPTY + } else { + new InputMetricsUIData( + bytesRead = metrics.bytesRead, + recordsRead = metrics.recordsRead) + } + } + private val EMPTY = InputMetricsUIData(0, 0) + } case class OutputMetricsUIData(bytesWritten: Long, recordsWritten: Long) + object OutputMetricsUIData { + def apply(metrics: OutputMetrics): OutputMetricsUIData = { + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0) { + EMPTY + } else { + new OutputMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten) + } + } + private val EMPTY = OutputMetricsUIData(0, 0) + } case class ShuffleReadMetricsUIData( remoteBlocksFetched: Long, @@ -212,17 +246,30 @@ private[spark] object UIData { object ShuffleReadMetricsUIData { def apply(metrics: ShuffleReadMetrics): ShuffleReadMetricsUIData = { - new ShuffleReadMetricsUIData( - remoteBlocksFetched = metrics.remoteBlocksFetched, - localBlocksFetched = metrics.localBlocksFetched, - remoteBytesRead = metrics.remoteBytesRead, - localBytesRead = metrics.localBytesRead, - fetchWaitTime = metrics.fetchWaitTime, - recordsRead = metrics.recordsRead, - totalBytesRead = metrics.totalBytesRead, - totalBlocksFetched = metrics.totalBlocksFetched - ) + if ( + metrics.remoteBlocksFetched == 0 && + metrics.localBlocksFetched == 0 && + metrics.remoteBytesRead == 0 && + metrics.localBytesRead == 0 && + metrics.fetchWaitTime == 0 && + metrics.recordsRead == 0 && + metrics.totalBytesRead == 0 && + metrics.totalBlocksFetched == 0) { + EMPTY + } else { + new ShuffleReadMetricsUIData( + remoteBlocksFetched = metrics.remoteBlocksFetched, + localBlocksFetched = metrics.localBlocksFetched, + remoteBytesRead = metrics.remoteBytesRead, + localBytesRead = metrics.localBytesRead, + fetchWaitTime = metrics.fetchWaitTime, + recordsRead = metrics.recordsRead, 
+ totalBytesRead = metrics.totalBytesRead, + totalBlocksFetched = metrics.totalBlocksFetched + ) + } } + private val EMPTY = ShuffleReadMetricsUIData(0, 0, 0, 0, 0, 0, 0, 0) } case class ShuffleWriteMetricsUIData( @@ -232,12 +279,17 @@ private[spark] object UIData { object ShuffleWriteMetricsUIData { def apply(metrics: ShuffleWriteMetrics): ShuffleWriteMetricsUIData = { - new ShuffleWriteMetricsUIData( - bytesWritten = metrics.bytesWritten, - recordsWritten = metrics.recordsWritten, - writeTime = metrics.writeTime - ) + if (metrics.bytesWritten == 0 && metrics.recordsWritten == 0 && metrics.writeTime == 0) { + EMPTY + } else { + new ShuffleWriteMetricsUIData( + bytesWritten = metrics.bytesWritten, + recordsWritten = metrics.recordsWritten, + writeTime = metrics.writeTime + ) + } } + private val EMPTY = ShuffleWriteMetricsUIData(0, 0, 0) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 0e330879d50f9..43bfe0aacf35b 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -222,7 +222,12 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - val label = s"${node.name} [${node.id}]\n${node.callsite}" + val isCached = if (node.cached) { + " [Cached]" + } else { + "" + } + val label = s"${node.name} [${node.id}]$isCached\n${node.callsite}" s"""${node.id} [label="${StringEscapeUtils.escapeJava(label)}"]""" } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 227e940c9c50c..317e0aa5ea25c 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { - val parameterId = request.getParameter("id") + // stripXSS is called first to remove suspicious characters used in XSS attacks + val parameterId = UIUtils.stripXSS(request.getParameter("id")) require(parameterId != null && parameterId.nonEmpty, "Missing id parameter") - val parameterBlockPage = request.getParameter("block.page") - val parameterBlockSortColumn = request.getParameter("block.sort") - val parameterBlockSortDesc = request.getParameter("block.desc") - val parameterBlockPageSize = request.getParameter("block.pageSize") - val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") + val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page")) + val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort")) + val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc")) + val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize")) + val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize")) val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") @@ -147,7 +148,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { /** Header fields for the worker table */ private def workerHeader = Seq( 
"Host", - "Memory Usage", + "On Heap Memory Usage", + "Off Heap Memory Usage", "Disk Usage") /** Render an HTML row representing a worker */ @@ -155,8 +157,12 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { {worker.address} - {Utils.bytesToString(worker.memoryUsed)} - ({Utils.bytesToString(worker.memoryRemaining)} Remaining) + {Utils.bytesToString(worker.onHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.onHeapMemoryRemaining.getOrElse(0L))} Remaining) + + + {Utils.bytesToString(worker.offHeapMemoryUsed.getOrElse(0L))} + ({Utils.bytesToString(worker.offHeapMemoryRemaining.getOrElse(0L))} Remaining) {Utils.bytesToString(worker.diskUsed)} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 76d7c6d414bcf..aa84788f1df88 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -151,7 +151,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { /** Render a stream block */ private def streamBlockTableRow(block: (BlockId, Seq[BlockUIData])): Seq[Node] = { val replications = block._2 - assert(replications.size > 0) // This must be true because it's the result of "groupBy" + assert(replications.nonEmpty) // This must be true because it's the result of "groupBy" if (replications.size == 1) { streamBlockTableSubrow(block._1, replications.head, replications.size, true) } else { diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be6..148efb134e14f 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index d3ddd39131326..1a9a6929541aa 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -59,8 +59,9 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } /** - * Returns true if this accumulator has been registered. Note that all accumulators must be - * registered before use, or it will throw exception. + * Returns true if this accumulator has been registered. + * + * @note All accumulators must be registered before use, or it will throw exception. 
*/ final def isRegistered: Boolean = metadata != null && AccumulatorContext.get(metadata.id).isDefined @@ -84,7 +85,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { */ final def name: Option[String] = { assertMetadataNotNull() - metadata.name + + if (atDriverSide) { + metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name)) + } else { + metadata.name + } } /** @@ -160,7 +166,17 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX) + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + // For non-internal accumulators, we still need to send the name because users may need to + // access the accumulator name at executor side, or they may keep the accumulators sent from + // executors and access the name when the registered accumulator is already garbage + // collected(e.g. SQLMetrics). + copyAcc.metadata = metadata + } copyAcc } else { this @@ -223,7 +239,7 @@ private[spark] object AccumulatorContext { * Registers an [[AccumulatorV2]] created on the driver such that it can be used on the executors. * * All accumulators registered here can later be used as a container for accumulating partial - * values across multiple tasks. This is what [[org.apache.spark.scheduler.DAGScheduler]] does. + * values across multiple tasks. This is what `org.apache.spark.scheduler.DAGScheduler` does. * Note: if an accumulator is registered here, it should also be registered with the active * context cleaner for cleanup so as to avoid memory leaks. * @@ -262,23 +278,13 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } /** - * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 64-bit integers. + * An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers. * * @since 2.0.0 */ diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala index dce2ac63a664c..50dc948e6c410 100644 --- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala @@ -23,11 +23,10 @@ import java.nio.ByteBuffer import org.apache.spark.storage.StorageUtils /** - * Reads data from a ByteBuffer, and optionally cleans it up using StorageUtils.dispose() - * at the end of the stream (e.g. to close a memory-mapped file). + * Reads data from a ByteBuffer. 
*/ private[spark] -class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = false) +class ByteBufferInputStream(private var buffer: ByteBuffer) extends InputStream { override def read(): Int = { @@ -72,9 +71,6 @@ class ByteBufferInputStream(private var buffer: ByteBuffer, dispose: Boolean = f */ private def cleanUp() { if (buffer != null) { - if (dispose) { - StorageUtils.dispose(buffer) - } buffer = null } } diff --git a/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala new file mode 100644 index 0000000000000..d73901686b705 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CommandLineUtils.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.PrintStream + +import org.apache.spark.SparkException + +/** + * Contains basic command line parsing functionality and methods to parse some common Spark CLI + * options. + */ +private[spark] trait CommandLineUtils { + + // Exposed for testing + private[spark] var exitFn: Int => Unit = (exitCode: Int) => System.exit(exitCode) + + private[spark] var printStream: PrintStream = System.err + + // scalastyle:off println + + private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str) + + private[spark] def printErrorAndExit(str: String): Unit = { + printStream.println("Error: " + str) + printStream.println("Run with --help for usage help or --verbose for debug output") + exitFn(1) + } + + // scalastyle:on println + + private[spark] def parseSparkConfProperty(pair: String): (String, String) = { + pair.split("=", 2).toSeq match { + case Seq(k, v) => (k, v) + case _ => printErrorAndExit(s"Spark config without '=': $pair") + throw new SparkException(s"Spark config without '=': $pair") + } + } + + def main(args: Array[String]): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index c11eb3ffa4601..8296c4294242c 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -107,20 +107,20 @@ private[spark] object JsonProtocol { def stageSubmittedToJson(stageSubmitted: SparkListenerStageSubmitted): JValue = { val stageInfo = stageInfoToJson(stageSubmitted.stageInfo) val properties = propertiesToJson(stageSubmitted.properties) - ("Event" -> Utils.getFormattedClassName(stageSubmitted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageSubmitted) ~ ("Stage Info" -> stageInfo) ~ ("Properties" -> properties) } def stageCompletedToJson(stageCompleted: SparkListenerStageCompleted): JValue = { val stageInfo = 
stageInfoToJson(stageCompleted.stageInfo) - ("Event" -> Utils.getFormattedClassName(stageCompleted)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.stageCompleted) ~ ("Stage Info" -> stageInfo) } def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = { val taskInfo = taskStart.taskInfo - ("Event" -> Utils.getFormattedClassName(taskStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskStart) ~ ("Stage ID" -> taskStart.stageId) ~ ("Stage Attempt ID" -> taskStart.stageAttemptId) ~ ("Task Info" -> taskInfoToJson(taskInfo)) @@ -128,7 +128,7 @@ private[spark] object JsonProtocol { def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = { val taskInfo = taskGettingResult.taskInfo - ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskGettingResult) ~ ("Task Info" -> taskInfoToJson(taskInfo)) } @@ -137,7 +137,7 @@ private[spark] object JsonProtocol { val taskInfo = taskEnd.taskInfo val taskMetrics = taskEnd.taskMetrics val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing - ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.taskEnd) ~ ("Stage ID" -> taskEnd.stageId) ~ ("Stage Attempt ID" -> taskEnd.stageAttemptId) ~ ("Task Type" -> taskEnd.taskType) ~ @@ -148,7 +148,7 @@ private[spark] object JsonProtocol { def jobStartToJson(jobStart: SparkListenerJobStart): JValue = { val properties = propertiesToJson(jobStart.properties) - ("Event" -> Utils.getFormattedClassName(jobStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobStart) ~ ("Job ID" -> jobStart.jobId) ~ ("Submission Time" -> jobStart.time) ~ ("Stage Infos" -> jobStart.stageInfos.map(stageInfoToJson)) ~ // Added in Spark 1.2.0 @@ -158,7 +158,7 @@ private[spark] object JsonProtocol { def jobEndToJson(jobEnd: SparkListenerJobEnd): JValue = { val jobResult = jobResultToJson(jobEnd.jobResult) - ("Event" -> Utils.getFormattedClassName(jobEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.jobEnd) ~ ("Job ID" -> jobEnd.jobId) ~ ("Completion Time" -> jobEnd.time) ~ ("Job Result" -> jobResult) @@ -170,7 +170,7 @@ private[spark] object JsonProtocol { val sparkProperties = mapToJson(environmentDetails("Spark Properties").toMap) val systemProperties = mapToJson(environmentDetails("System Properties").toMap) val classpathEntries = mapToJson(environmentDetails("Classpath Entries").toMap) - ("Event" -> Utils.getFormattedClassName(environmentUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.environmentUpdate) ~ ("JVM Information" -> jvmInformation) ~ ("Spark Properties" -> sparkProperties) ~ ("System Properties" -> systemProperties) ~ @@ -179,26 +179,28 @@ private[spark] object JsonProtocol { def blockManagerAddedToJson(blockManagerAdded: SparkListenerBlockManagerAdded): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerAdded) ~ ("Block Manager ID" -> blockManagerId) ~ ("Maximum Memory" -> blockManagerAdded.maxMem) ~ - ("Timestamp" -> blockManagerAdded.time) + ("Timestamp" -> blockManagerAdded.time) ~ + ("Maximum Onheap Memory" -> blockManagerAdded.maxOnHeapMem) ~ + ("Maximum Offheap Memory" -> blockManagerAdded.maxOffHeapMem) } def blockManagerRemovedToJson(blockManagerRemoved: 
SparkListenerBlockManagerRemoved): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId) - ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockManagerRemoved) ~ ("Block Manager ID" -> blockManagerId) ~ ("Timestamp" -> blockManagerRemoved.time) } def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = { - ("Event" -> Utils.getFormattedClassName(unpersistRDD)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.unpersistRDD) ~ ("RDD ID" -> unpersistRDD.rddId) } def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationStart) ~ ("App Name" -> applicationStart.appName) ~ ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~ ("Timestamp" -> applicationStart.time) ~ @@ -208,33 +210,33 @@ private[spark] object JsonProtocol { } def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = { - ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.applicationEnd) ~ ("Timestamp" -> applicationEnd.time) } def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { - ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorAdded) ~ ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { - ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.executorRemoved) ~ ("Timestamp" -> executorRemoved.time) ~ ("Executor ID" -> executorRemoved.executorId) ~ ("Removed Reason" -> executorRemoved.reason) } def logStartToJson(logStart: SparkListenerLogStart): JValue = { - ("Event" -> Utils.getFormattedClassName(logStart)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.logStart) ~ ("Spark Version" -> SPARK_VERSION) } def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { val execId = metricsUpdate.execId val accumUpdates = metricsUpdate.accumUpdates - ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.metricsUpdate) ~ ("Executor ID" -> execId) ~ ("Metrics Updated" -> accumUpdates.map { case (taskId, stageId, stageAttemptId, updates) => ("Task ID" -> taskId) ~ @@ -264,8 +266,7 @@ private[spark] object JsonProtocol { ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ ("Failure Reason" -> failureReason) ~ - ("Accumulables" -> JArray( - stageInfo.accumulables.values.map(accumulableInfoToJson).toList)) + ("Accumulables" -> accumulablesToJson(stageInfo.accumulables.values)) } def taskInfoToJson(taskInfo: TaskInfo): JValue = { @@ -281,7 +282,15 @@ private[spark] object JsonProtocol { ("Finish Time" -> taskInfo.finishTime) ~ ("Failed" -> taskInfo.failed) ~ ("Killed" -> taskInfo.killed) ~ - ("Accumulables" -> JArray(taskInfo.accumulables.toList.map(accumulableInfoToJson))) + ("Accumulables" -> accumulablesToJson(taskInfo.accumulables)) + } + + private lazy val accumulableBlacklist = Set("internal.metrics.updatedBlockStatuses") + + def accumulablesToJson(accumulables: 
Traversable[AccumulableInfo]): JArray = { + JArray(accumulables + .filterNot(_.name.exists(accumulableBlacklist.contains)) + .toList.map(accumulableInfoToJson)) } def accumulableInfoToJson(accumulableInfo: AccumulableInfo): JValue = { @@ -376,7 +385,7 @@ private[spark] object JsonProtocol { ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => val stackTrace = stackTraceToJson(exceptionFailure.stackTrace) - val accumUpdates = JArray(exceptionFailure.accumUpdates.map(accumulableInfoToJson).toList) + val accumUpdates = accumulablesToJson(exceptionFailure.accumUpdates) ("Class Name" -> exceptionFailure.className) ~ ("Description" -> exceptionFailure.description) ~ ("Stack Trace" -> stackTrace) ~ @@ -390,6 +399,8 @@ private[spark] object JsonProtocol { ("Executor ID" -> executorId) ~ ("Exit Caused By App" -> exitCausedByApp) ~ ("Loss Reason" -> reason.map(_.toString)) + case taskKilled: TaskKilled => + ("Kill Reason" -> taskKilled.reason) case _ => Utils.emptyJson } ("Reason" -> reason) ~ json @@ -485,7 +496,7 @@ private[spark] object JsonProtocol { * JSON deserialization methods for SparkListenerEvents | * ---------------------------------------------------- */ - def sparkEventFromJson(json: JValue): SparkListenerEvent = { + private object SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES { val stageSubmitted = Utils.getFormattedClassName(SparkListenerStageSubmitted) val stageCompleted = Utils.getFormattedClassName(SparkListenerStageCompleted) val taskStart = Utils.getFormattedClassName(SparkListenerTaskStart) @@ -503,6 +514,10 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + } + + def sparkEventFromJson(json: JValue): SparkListenerEvent = { + import SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES._ (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -540,7 +555,8 @@ private[spark] object JsonProtocol { def taskStartFromJson(json: JValue): SparkListenerTaskStart = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskInfo = taskInfoFromJson(json \ "Task Info") SparkListenerTaskStart(stageId, stageAttemptId, taskInfo) } @@ -552,7 +568,8 @@ private[spark] object JsonProtocol { def taskEndFromJson(json: JValue): SparkListenerTaskEnd = { val stageId = (json \ "Stage ID").extract[Int] - val stageAttemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val stageAttemptId = + Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val taskType = (json \ "Task Type").extract[String] val taskEndReason = taskEndReasonFromJson(json \ "Task End Reason") val taskInfo = taskInfoFromJson(json \ "Task Info") @@ -597,7 +614,9 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) - SparkListenerBlockManagerAdded(time, blockManagerId, maxMem) + val maxOnHeapMem = Utils.jsonOption(json \ "Maximum Onheap Memory").map(_.extract[Long]) + val maxOffHeapMem = Utils.jsonOption(json \ "Maximum Offheap Memory").map(_.extract[Long]) + 
SparkListenerBlockManagerAdded(time, blockManagerId, maxMem, maxOnHeapMem, maxOffHeapMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { @@ -662,20 +681,22 @@ private[spark] object JsonProtocol { def stageInfoFromJson(json: JValue): StageInfo = { val stageId = (json \ "Stage ID").extract[Int] - val attemptId = (json \ "Stage Attempt ID").extractOpt[Int].getOrElse(0) + val attemptId = Utils.jsonOption(json \ "Stage Attempt ID").map(_.extract[Int]).getOrElse(0) val stageName = (json \ "Stage Name").extract[String] val numTasks = (json \ "Number of Tasks").extract[Int] val rddInfos = (json \ "RDD Info").extract[List[JValue]].map(rddInfoFromJson) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) - val details = (json \ "Details").extractOpt[String].getOrElse("") + val details = Utils.jsonOption(json \ "Details").map(_.extract[String]).getOrElse("") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) - val accumulatedValues = (json \ "Accumulables").extractOpt[List[JValue]] match { - case Some(values) => values.map(accumulableInfoFromJson) - case None => Seq[AccumulableInfo]() + val accumulatedValues = { + Utils.jsonOption(json \ "Accumulables").map(_.extract[List[JValue]]) match { + case Some(values) => values.map(accumulableInfoFromJson) + case None => Seq[AccumulableInfo]() + } } val stageInfo = new StageInfo( @@ -692,17 +713,17 @@ private[spark] object JsonProtocol { def taskInfoFromJson(json: JValue): TaskInfo = { val taskId = (json \ "Task ID").extract[Long] val index = (json \ "Index").extract[Int] - val attempt = (json \ "Attempt").extractOpt[Int].getOrElse(1) + val attempt = Utils.jsonOption(json \ "Attempt").map(_.extract[Int]).getOrElse(1) val launchTime = (json \ "Launch Time").extract[Long] - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val taskLocality = TaskLocality.withName((json \ "Locality").extract[String]) - val speculative = (json \ "Speculative").extractOpt[Boolean].getOrElse(false) + val speculative = Utils.jsonOption(json \ "Speculative").exists(_.extract[Boolean]) val gettingResultTime = (json \ "Getting Result Time").extract[Long] val finishTime = (json \ "Finish Time").extract[Long] val failed = (json \ "Failed").extract[Boolean] - val killed = (json \ "Killed").extractOpt[Boolean].getOrElse(false) - val accumulables = (json \ "Accumulables").extractOpt[Seq[JValue]] match { + val killed = Utils.jsonOption(json \ "Killed").exists(_.extract[Boolean]) + val accumulables = Utils.jsonOption(json \ "Accumulables").map(_.extract[Seq[JValue]]) match { case Some(values) => values.map(accumulableInfoFromJson) case None => Seq[AccumulableInfo]() } @@ -713,18 +734,19 @@ private[spark] object JsonProtocol { taskInfo.finishTime = finishTime taskInfo.failed = failed taskInfo.killed = killed - accumulables.foreach { taskInfo.accumulables += _ } + taskInfo.setAccumulables(accumulables) taskInfo } def accumulableInfoFromJson(json: JValue): AccumulableInfo = { val id = (json \ "ID").extract[Long] - val name = (json \ "Name").extractOpt[String] + val name = Utils.jsonOption(json \ 
"Name").map(_.extract[String]) val update = Utils.jsonOption(json \ "Update").map { v => accumValueFromJson(name, v) } val value = Utils.jsonOption(json \ "Value").map { v => accumValueFromJson(name, v) } - val internal = (json \ "Internal").extractOpt[Boolean].getOrElse(false) - val countFailedValues = (json \ "Count Failed Values").extractOpt[Boolean].getOrElse(false) - val metadata = (json \ "Metadata").extractOpt[String] + val internal = Utils.jsonOption(json \ "Internal").exists(_.extract[Boolean]) + val countFailedValues = + Utils.jsonOption(json \ "Count Failed Values").exists(_.extract[Boolean]) + val metadata = Utils.jsonOption(json \ "Metadata").map(_.extract[String]) new AccumulableInfo(id, name, update, value, internal, countFailedValues, metadata) } @@ -782,9 +804,11 @@ private[spark] object JsonProtocol { readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) - readMetrics.incLocalBytesRead((readJson \ "Local Bytes Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incLocalBytesRead( + Utils.jsonOption(readJson \ "Local Bytes Read").map(_.extract[Long]).getOrElse(0L)) readMetrics.incFetchWaitTime((readJson \ "Fetch Wait Time").extract[Long]) - readMetrics.incRecordsRead((readJson \ "Total Records Read").extractOpt[Long].getOrElse(0L)) + readMetrics.incRecordsRead( + Utils.jsonOption(readJson \ "Total Records Read").map(_.extract[Long]).getOrElse(0L)) metrics.mergeShuffleReadMetrics() } @@ -793,8 +817,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) - writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") - .extractOpt[Long].getOrElse(0L)) + writeMetrics.incRecordsWritten( + Utils.jsonOption(writeJson \ "Shuffle Records Written").map(_.extract[Long]).getOrElse(0L)) writeMetrics.incWriteTime((writeJson \ "Shuffle Write Time").extract[Long]) } @@ -802,14 +826,16 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Output Metrics").foreach { outJson => val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) - outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) + outputMetrics.setRecordsWritten( + Utils.jsonOption(outJson \ "Records Written").map(_.extract[Long]).getOrElse(0L)) } // Input metrics Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incRecordsRead( + Utils.jsonOption(inJson \ "Records Read").map(_.extract[Long]).getOrElse(0L)) } // Updated blocks @@ -824,7 +850,7 @@ private[spark] object JsonProtocol { metrics } - def taskEndReasonFromJson(json: JValue): TaskEndReason = { + private object TASK_END_REASON_FORMATTED_CLASS_NAMES { val success = Utils.getFormattedClassName(Success) val resubmitted = Utils.getFormattedClassName(Resubmitted) val fetchFailed = Utils.getFormattedClassName(FetchFailed) @@ -834,6 +860,10 @@ private[spark] object JsonProtocol { val taskCommitDenied = Utils.getFormattedClassName(TaskCommitDenied) val executorLostFailure = 
Utils.getFormattedClassName(ExecutorLostFailure) val unknownReason = Utils.getFormattedClassName(UnknownReason) + } + + def taskEndReasonFromJson(json: JValue): TaskEndReason = { + import TASK_END_REASON_FORMATTED_CLASS_NAMES._ (json \ "Reason").extract[String] match { case `success` => Success @@ -850,7 +880,8 @@ private[spark] object JsonProtocol { val className = (json \ "Class Name").extract[String] val description = (json \ "Description").extract[String] val stackTrace = stackTraceFromJson(json \ "Stack Trace") - val fullStackTrace = (json \ "Full Stack Trace").extractOpt[String].orNull + val fullStackTrace = + Utils.jsonOption(json \ "Full Stack Trace").map(_.extract[String]).orNull // Fallback on getting accumulator updates from TaskMetrics, which was logged in Spark 1.x val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates") .map(_.extract[List[JValue]].map(accumulableInfoFromJson)) @@ -859,7 +890,10 @@ private[spark] object JsonProtocol { })) ExceptionFailure(className, description, stackTrace, fullStackTrace, None, accumUpdates) case `taskResultLost` => TaskResultLost - case `taskKilled` => TaskKilled + case `taskKilled` => + val killReason = Utils.jsonOption(json \ "Kill Reason") + .map(_.extract[String]).getOrElse("unknown reason") + TaskKilled(killReason) case `taskCommitDenied` => // Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON // de/serialization logic was not added until 1.5.1. To provide backward compatibility @@ -885,15 +919,19 @@ private[spark] object JsonProtocol { if (json == JNothing) { return null } - val executorId = (json \ "Executor ID").extract[String] - val host = (json \ "Host").extract[String] + val executorId = (json \ "Executor ID").extract[String].intern() + val host = (json \ "Host").extract[String].intern() val port = (json \ "Port").extract[Int] BlockManagerId(executorId, host, port) } - def jobResultFromJson(json: JValue): JobResult = { + private object JOB_RESULT_FORMATTED_CLASS_NAMES { val jobSucceeded = Utils.getFormattedClassName(JobSucceeded) val jobFailed = Utils.getFormattedClassName(JobFailed) + } + + def jobResultFromJson(json: JValue): JobResult = { + import JOB_RESULT_FORMATTED_CLASS_NAMES._ (json \ "Result").extract[String] match { case `jobSucceeded` => JobSucceeded diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala index 79fc2e94599c7..fa5ad4e8d81e1 100644 --- a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -52,7 +52,7 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { * Post the event to all registered listeners. The `postToAll` caller should guarantee calling * `postToAll` in the same thread for all events. */ - final def postToAll(event: E): Unit = { + def postToAll(event: E): Unit = { // JavaConverters can create a JIterableWrapper if we use asScala. // However, this method will be called frequently. To avoid the wrapper cost, here we use // Java Iterator directly. 
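The JsonProtocol changes above route optional fields through `Utils.jsonOption` so that events written by older Spark versions, which lack fields such as the new "Kill Reason", can still be replayed without failing. The following is a minimal, self-contained sketch of that pattern using json4s directly; the object name and the `jsonOption` helper are hypothetical stand-ins for illustration, not the code being patched.

```scala
import org.json4s._
import org.json4s.jackson.JsonMethods.parse

object JsonCompatSketch {
  implicit val formats: Formats = DefaultFormats

  // Hypothetical helper mirroring the idea behind Utils.jsonOption: a missing or
  // null field becomes None instead of going through extractOpt, which would also
  // silently swallow genuine extraction errors.
  def jsonOption(json: JValue): Option[JValue] = json match {
    case JNothing | JNull => None
    case value => Some(value)
  }

  def main(args: Array[String]): Unit = {
    // An event log written before the "Kill Reason" field existed.
    val oldEvent = parse("""{"Reason": "TaskKilled"}""")
    val killReason = jsonOption(oldEvent \ "Kill Reason")
      .map(_.extract[String])
      .getOrElse("unknown reason")
    println(killReason) // prints: unknown reason
  }
}
```

Reading the absent field yields `JNothing`, which the helper maps to `None`, so the deserializer falls back to a default instead of throwing.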
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4e..ce06e18879a49 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index e3b588374ce1a..e5cccf39f9455 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -23,12 +23,12 @@ import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} private[spark] object RpcUtils { /** - * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + * Retrieve a `RpcEndpointRef` which is located in the driver via its name. */ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") + Utils.checkHost(driverHost) rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 386fdfd218a88..3bfdf95db84c6 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -350,7 +350,7 @@ object SizeEstimator extends Logging { // 3. consistent fields layouts throughout the hierarchy: This means we should layout // superclass first. And we can use superclass's shellSize as a starting point to layout the // other fields in this class. - // 4. class alignment: HotSpot rounds field blocks up to to HeapOopSize not 4 bytes, confirmed + // 4. class alignment: HotSpot rounds field blocks up to HeapOopSize not 4 bytes, confirmed // with Aleksey. see https://bugs.openjdk.java.net/browse/CODETOOLS-7901322 // // The real world field layout is much more complicated. 
There are three kinds of fields diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index 45381365f1e52..1e02638591f8b 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -22,8 +22,8 @@ import org.apache.spark.annotation.Since /** * A class for tracking the statistics of a set of numbers (count, mean and variance) in a * numerically robust way. Includes support for merging two StatCounters. Based on Welford - * and Chan's [[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance algorithms]] - * for running variance. + * and Chan's + * algorithms for running variance. * * @constructor Initialize the StatCounter with the given values. */ diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala index d4e0ad93b966a..b1217980faf1f 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala @@ -24,4 +24,8 @@ private[spark] case class ThreadStackTrace( threadId: Long, threadName: String, threadState: Thread.State, - stackTrace: String) + stackTrace: String, + blockedByThreadId: Option[Long], + blockedByLock: String, + holdingLocks: Seq[String]) + diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index d093e7bfc3dac..1aa4456ed01b4 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import java.util.concurrent._ -import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor} import scala.concurrent.duration.Duration import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal @@ -180,39 +180,30 @@ private[spark] object ThreadUtils { // scalastyle:off awaitresult /** - * Preferred alternative to [[Await.result()]]. This method wraps and re-throws any exceptions - * thrown by the underlying [[Await]] call, ensuring that this thread's stack trace appears in - * logs. - */ - @throws(classOf[SparkException]) - def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { - try { - Await.result(awaitable, atMost) - // scalastyle:on awaitresult - } catch { - case NonFatal(t) => - throw new SparkException("Exception thrown in awaitResult: ", t) - } - } - - /** - * Calls [[Awaitable.result]] directly to avoid using `ForkJoinPool`'s `BlockingContext`, wraps - * and re-throws any exceptions with nice stack track. + * Preferred alternative to `Await.result()`. + * + * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring + * that this thread's stack trace appears in logs. * - * Codes running in the user's thread may be in a thread of Scala ForkJoinPool. As concurrent - * executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this method - * basically prevents ForkJoinPool from running other tasks in the current waiting thread. + * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s + * `BlockingContext`. 
Code running in the user's thread may be in a thread of Scala ForkJoinPool. + * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this + * method basically prevents ForkJoinPool from running other tasks in the current waiting thread. + * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's + * hard to debug when [[ThreadLocal]]s leak to other tasks. */ @throws(classOf[SparkException]) - def awaitResultInForkJoinSafely[T](awaitable: Awaitable[T], atMost: Duration): T = { + def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { try { // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. // See SPARK-13747. val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - awaitable.result(Duration.Inf)(awaitPermission) + awaitable.result(atMost)(awaitPermission) } catch { - case NonFatal(t) => + // TimeoutException is thrown in the current thread, so there is no need to wrap the exception. + case NonFatal(t) if !t.isInstanceOf[TimeoutException] => throw new SparkException("Exception thrown in awaitResult: ", t) } } + // scalastyle:on awaitresult } diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e29..27922b31949b6 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. */ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6027b07c0fee8..edfe229792323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,8 @@ package org.apache.spark.util import java.io._ -import java.lang.management.ManagementFactory +import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels @@ -38,7 +39,9 @@ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import scala.util.matching.Regex +import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses @@ -54,7 +57,7 @@ import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.{DYN_ALLOCATION_INITIAL_EXECUTORS, DYN_ALLOCATION_MIN_EXECUTORS, EXECUTOR_INSTANCES} +import org.apache.spark.internal.config._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import 
org.apache.spark.util.logging.RollingFileAppender @@ -236,9 +239,11 @@ private[spark] object Utils extends Logging { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { + val originalPosition = bb.position() val bbval = new Array[Byte](bb.remaining()) bb.get(bbval) out.write(bbval) + bb.position(originalPosition) } } @@ -249,9 +254,11 @@ private[spark] object Utils extends Logging { if (bb.hasArray) { out.write(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) } else { + val originalPosition = bb.position() val bbval = new Array[Byte](bb.remaining()) bb.get(bbval) out.write(bbval) + bb.position(originalPosition) } } @@ -733,7 +740,11 @@ private[spark] object Utils extends Logging { * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { @@ -926,12 +937,13 @@ private[spark] object Utils extends Logging { customHostname.getOrElse(InetAddresses.toUriString(localIpAddress)) } - def checkHost(host: String, message: String = "") { - assert(host.indexOf(':') == -1, message) + def checkHost(host: String) { + assert(host != null && host.indexOf(':') == -1, s"Expected hostname (not IP) but got $host") } - def checkHostPort(hostPort: String, message: String = "") { - assert(hostPort.indexOf(':') != -1, message) + def checkHostPort(hostPort: String) { + assert(hostPort != null && hostPort.indexOf(':') != -1, + s"Expected host and port but got $hostPort") } // Typically, this will be of order of number of nodes in cluster @@ -1104,26 +1116,39 @@ private[spark] object Utils extends Logging { /** * Convert a quantity in bytes to a human-readable string such as "4.0 MB". */ - def bytesToString(size: Long): String = { + def bytesToString(size: Long): String = bytesToString(BigInt(size)) + + def bytesToString(size: BigInt): String = { + val EB = 1L << 60 + val PB = 1L << 50 val TB = 1L << 40 val GB = 1L << 30 val MB = 1L << 20 val KB = 1L << 10 - val (value, unit) = { - if (size >= 2*TB) { - (size.asInstanceOf[Double] / TB, "TB") - } else if (size >= 2*GB) { - (size.asInstanceOf[Double] / GB, "GB") - } else if (size >= 2*MB) { - (size.asInstanceOf[Double] / MB, "MB") - } else if (size >= 2*KB) { - (size.asInstanceOf[Double] / KB, "KB") - } else { - (size.asInstanceOf[Double], "B") + if (size >= BigInt(1L << 11) * EB) { + // The number is too large, show it in scientific notation. 
+ BigDecimal(size, new MathContext(3, RoundingMode.HALF_UP)).toString() + " B" + } else { + val (value, unit) = { + if (size >= 2 * EB) { + (BigDecimal(size) / EB, "EB") + } else if (size >= 2 * PB) { + (BigDecimal(size) / PB, "PB") + } else if (size >= 2 * TB) { + (BigDecimal(size) / TB, "TB") + } else if (size >= 2 * GB) { + (BigDecimal(size) / GB, "GB") + } else if (size >= 2 * MB) { + (BigDecimal(size) / MB, "MB") + } else if (size >= 2 * KB) { + (BigDecimal(size) / KB, "KB") + } else { + (BigDecimal(size), "B") + } } + "%.1f %s".formatLocal(Locale.US, value, unit) } - "%.1f %s".formatLocal(Locale.US, value, unit) } /** @@ -1248,7 +1273,7 @@ private[spark] object Utils extends Logging { val currentThreadName = Thread.currentThread().getName if (sc != null) { logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t) - sc.stop() + sc.stopInNewThread() } if (!NonFatal(t)) { logError(s"throw uncaught fatal error in thread $currentThreadName", t) @@ -1418,8 +1443,12 @@ private[spark] object Utils extends Logging { } callStack(0) = ste.toString // Put last Spark method on top of the stack trace. } else { - firstUserLine = ste.getLineNumber - firstUserFile = ste.getFileName + if (ste.getFileName != null) { + firstUserFile = ste.getFileName + if (ste.getLineNumber >= 0) { + firstUserLine = ste.getLineNumber + } + } callStack += ste.toString insideSpark = false } @@ -1479,10 +1508,11 @@ private[spark] object Utils extends Logging { /** Return uncompressed file length of a compressed file. */ private def getCompressedFileLength(file: File): Long = { + var gzInputStream: GZIPInputStream = null try { // Uncompress .gz file to determine file size. var fileSize = 0L - val gzInputStream = new GZIPInputStream(new FileInputStream(file)) + gzInputStream = new GZIPInputStream(new FileInputStream(file)) val bufSize = 1024 val buf = new Array[Byte](bufSize) var numBytes = ByteStreams.read(gzInputStream, buf, 0, bufSize) @@ -1495,6 +1525,10 @@ private[spark] object Utils extends Logging { case e: Throwable => logError(s"Cannot get file length of ${file}", e) throw e + } finally { + if (gzInputStream != null) { + gzInputStream.close() + } } } @@ -1668,8 +1702,8 @@ private[spark] object Utils extends Logging { } /** - * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared - * according to semantics where NaN == NaN and NaN > any non-NaN double. + * NaN-safe version of `java.lang.Double.compare()` which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN is greater than any non-NaN double. */ def nanSafeCompareDoubles(x: Double, y: Double): Int = { val xIsNan: Boolean = java.lang.Double.isNaN(x) @@ -1682,8 +1716,8 @@ private[spark] object Utils extends Logging { } /** - * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared - * according to semantics where NaN == NaN and NaN > any non-NaN float. + * NaN-safe version of `java.lang.Float.compare()` which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN is greater than any non-NaN float. 
*/ def nanSafeCompareFloats(x: Float, y: Float): Int = { val xIsNan: Boolean = java.lang.Float.isNaN(x) @@ -1868,20 +1902,17 @@ private[spark] object Utils extends Logging { def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { // Politely destroy first process.destroy() - - if (waitForProcess(process, timeoutMs)) { + if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) { // Successful exit Option(process.exitValue()) } else { - // Java 8 added a new API which will more forcibly kill the process. Use that if available. try { - classOf[Process].getMethod("destroyForcibly").invoke(process) + process.destroyForcibly() } catch { - case _: NoSuchMethodException => return None // Not available; give up case NonFatal(e) => logWarning("Exception when attempting to kill process", e) } // Wait, again, although this really should return almost immediately - if (waitForProcess(process, timeoutMs)) { + if (process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) { Option(process.exitValue()) } else { logWarning("Timed out waiting to forcibly kill process") @@ -1890,45 +1921,12 @@ private[spark] object Utils extends Logging { } } - /** - * Wait for a process to terminate for at most the specified duration. - * - * @return whether the process actually terminated before the given timeout. - */ - def waitForProcess(process: Process, timeoutMs: Long): Boolean = { - try { - // Use Java 8 method if available - classOf[Process].getMethod("waitFor", java.lang.Long.TYPE, classOf[TimeUnit]) - .invoke(process, timeoutMs.asInstanceOf[java.lang.Long], TimeUnit.MILLISECONDS) - .asInstanceOf[Boolean] - } catch { - case _: NoSuchMethodException => - // Otherwise implement it manually - var terminated = false - val startTime = System.currentTimeMillis - while (!terminated) { - try { - process.exitValue() - terminated = true - } catch { - case e: IllegalThreadStateException => - // Process not terminated yet - if (System.currentTimeMillis - startTime > timeoutMs) { - return false - } - Thread.sleep(100) - } - } - true - } - } - /** * Return the stderr of a process after waiting for the process to terminate. * If the process does not terminate within the specified timeout, return None. */ def getStderr(process: Process, timeoutMs: Long): Option[String] = { - val terminated = Utils.waitForProcess(process, timeoutMs) + val terminated = process.waitFor(timeoutMs, TimeUnit.MILLISECONDS) if (terminated) { Some(Source.fromInputStream(process.getErrorStream).getLines().mkString("\n")) } else { @@ -2011,7 +2009,7 @@ private[spark] object Utils extends Logging { if (paths == null || paths.trim.isEmpty) { "" } else { - paths.split(",").map { p => Utils.resolveURI(p) }.mkString(",") + paths.split(",").filter(_.trim.nonEmpty).map { p => Utils.resolveURI(p) }.mkString(",") } } @@ -2051,6 +2049,20 @@ private[spark] object Utils extends Logging { path } + /** + * Updates Spark config with properties from a set of Properties. + * Provided properties have the highest priority. + */ + def updateSparkConfigFromProperties( + conf: SparkConf, + properties: Map[String, String]) : Unit = { + properties.filter { case (k, v) => + k.startsWith("spark.") + }.foreach { case (k, v) => + conf.set(k, v) + } + } + /** Load properties present in the given file. 
*/ def getPropertiesFromFile(filename: String): Map[String, String] = { val file = new File(filename) @@ -2096,18 +2108,62 @@ private[spark] object Utils extends Logging { } } + private implicit class Lock(lock: LockInfo) { + def lockString: String = { + lock match { + case monitor: MonitorInfo => + s"Monitor(${lock.getClassName}@${lock.getIdentityHashCode}})" + case _ => + s"Lock(${lock.getClassName}@${lock.getIdentityHashCode}})" + } + } + } + /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */ def getThreadDump(): Array[ThreadStackTrace] = { // We need to filter out null values here because dumpAllThreads() may return null array // elements for threads that are dead / don't exist. val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null) - threadInfos.sortBy(_.getThreadId).map { case threadInfo => - val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n") - ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName, - threadInfo.getThreadState, stackTrace) + threadInfos.sortBy(_.getThreadId).map(threadInfoToThreadStackTrace) + } + + def getThreadDumpForThread(threadId: Long): Option[ThreadStackTrace] = { + if (threadId <= 0) { + None + } else { + // The Int.MaxValue here requests the entire untruncated stack trace of the thread: + val threadInfo = + Option(ManagementFactory.getThreadMXBean.getThreadInfo(threadId, Int.MaxValue)) + threadInfo.map(threadInfoToThreadStackTrace) } } + private def threadInfoToThreadStackTrace(threadInfo: ThreadInfo): ThreadStackTrace = { + val monitors = threadInfo.getLockedMonitors.map(m => m.getLockedStackFrame -> m).toMap + val stackTrace = threadInfo.getStackTrace.map { frame => + monitors.get(frame) match { + case Some(monitor) => + monitor.getLockedStackFrame.toString + s" => holding ${monitor.lockString}" + case None => + frame.toString + } + }.mkString("\n") + + // use a set to dedup re-entrant locks that are held at multiple places + val heldLocks = + (threadInfo.getLockedSynchronizers ++ threadInfo.getLockedMonitors).map(_.lockString).toSet + + ThreadStackTrace( + threadId = threadInfo.getThreadId, + threadName = threadInfo.getThreadName, + threadState = threadInfo.getThreadState, + stackTrace = stackTrace, + blockedByThreadId = + if (threadInfo.getLockOwnerId < 0) None else Some(threadInfo.getLockOwnerId), + blockedByLock = Option(threadInfo.getLockInfo).map(_.lockString).getOrElse(""), + holdingLocks = heldLocks.toSeq) + } + /** * Convert all spark properties set in the given SparkConf to a sequence of java options. */ @@ -2130,6 +2186,14 @@ private[spark] object Utils extends Logging { } } + /** + * Returns the user port to try when trying to bind a service. Handles wrapping and skipping + * privileged ports. + */ + def userPort(base: Int, offset: Int): Int = { + (base + offset - 1024) % (65536 - 1024) + 1024 + } + /** * Attempt to start a service on the given port, or fail after a number of attempts. * Each subsequent attempt uses 1 + the port used in the previous attempt (unless the port is 0). 
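The `userPort` helper added above keeps bind retries inside the unprivileged range by wrapping around modulo 65536 - 1024. A small standalone sketch of that wrap-around behaviour, with a hypothetical object name and example values:

```scala
object PortRetrySketch {
  // Same arithmetic as the userPort helper above: offsets wrap around the
  // ephemeral range [1024, 65536) instead of dropping into privileged ports.
  def userPort(base: Int, offset: Int): Int =
    (base + offset - 1024) % (65536 - 1024) + 1024

  def main(args: Array[String]): Unit = {
    // Starting near the top of the range, the third attempt wraps back to 1024
    // rather than trying a port below 1024.
    val attempts = (0 to 3).map(offset => userPort(65534, offset))
    println(attempts.mkString(", ")) // 65534, 65535, 1024, 1025
  }
}
```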
@@ -2157,8 +2221,7 @@ private[spark] object Utils extends Logging { val tryPort = if (startPort == 0) { startPort } else { - // If the new port wraps around, do not try a privilege port - ((startPort + offset - 1024) % (65536 - 1024)) + 1024 + userPort(startPort, offset) } try { val (service, port) = startService(tryPort) @@ -2167,17 +2230,32 @@ } catch { case e: Exception if isBindCollision(e) => if (offset >= maxRetries) { - val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " + - s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + - s"the appropriate port for the service$serviceString (for example spark.ui.port " + - s"for SparkUI) to an available port or increasing spark.port.maxRetries." + val exceptionMessage = if (startPort == 0) { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (on a random free port)! " + + s"Consider explicitly setting the appropriate binding address for " + + s"the service$serviceString (for example spark.driver.bindAddress " + + s"for SparkDriver) to the correct binding address." + } else { + s"${e.getMessage}: Service$serviceString failed after " + + s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " + + s"the appropriate port for the service$serviceString (for example spark.ui.port " + + s"for SparkUI) to an available port or increasing spark.port.maxRetries." + } val exception = new BindException(exceptionMessage) // restore original stack trace exception.setStackTrace(e.getStackTrace) throw exception } - logWarning(s"Service$serviceString could not bind on port $tryPort. " + - s"Attempting port ${tryPort + 1}.") + if (startPort == 0) { + // As startPort 0 requests a random free port, a bind failure most likely means that the + // binding address is not correct. + logWarning(s"Service$serviceString could not bind on a random free port. " + + "You may want to check whether an appropriate binding address has been configured.") + } else { + logWarning(s"Service$serviceString could not bind on port $tryPort. " + + s"Attempting port ${tryPort + 1}.") + } } } // Should never happen @@ -2196,6 +2274,9 @@ private[spark] object Utils extends Logging { isBindCollision(e.getCause) case e: MultiException => e.getThrowables.asScala.exists(isBindCollision) + case e: NativeIoException => + (e.getMessage != null && e.getMessage.startsWith("bind() failed: ")) || + isBindCollision(e.getCause) case e: Exception => isBindCollision(e.getCause) case _ => false } @@ -2306,8 +2387,9 @@ private[spark] object Utils extends Logging { * A spark url (`spark://host:port`) is a special URI whose scheme is `spark` and which only contains * host and port. * - * @throws SparkException if `sparkUrl` is invalid. + * @throws org.apache.spark.SparkException if sparkUrl is invalid. */ + @throws(classOf[SparkException]) def extractHostPortFromSparkUrl(sparkUrl: String): (String, Int) = { try { val uri = new java.net.URI(sparkUrl) @@ -2507,16 +2589,71 @@ private[spark] object Utils extends Logging { sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten } } + + private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)" + + /** + * Redact the sensitive values in the given map. If a map key matches the redaction pattern then + * its value is replaced with dummy text.
+ */ + def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { + val redactionPattern = conf.get(SECRET_REDACTION_PATTERN) + redact(redactionPattern, kvs) + } + + /** + * Redact the sensitive information in the given string. + */ + def redact(conf: SparkConf, text: String): String = { + if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text + val regex = conf.get(STRING_REDACTION_PATTERN).get + regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT) + } + + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { + // If the sensitive information regex matches either the key or the value, redact the value. + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accepting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, the user would have to make the spark.redaction.regex + // property more specific. + kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) + } + } + + /** + * Looks up the redaction regex from within the key value pairs and uses it to redact the rest + * of the key value pairs. No care is taken to make sure the redaction property itself is not + * redacted. So theoretically, the property itself could be configured to redact its own value + * when printing.
+ */ + def redact(kvs: Map[String, String]): Seq[(String, String)] = { + val redactionPattern = kvs.getOrElse( + SECRET_REDACTION_PATTERN.key, + SECRET_REDACTION_PATTERN.defaultValueString + ).r + redact(redactionPattern, kvs.toArray) + } + } private[util] object CallerContext extends Logging { val callerContextSupported: Boolean = { SparkHadoopUtil.get.conf.getBoolean("hadoop.caller.context.enabled", false) && { try { - // scalastyle:off classforname - Class.forName("org.apache.hadoop.ipc.CallerContext") - Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + Utils.classForName("org.apache.hadoop.ipc.CallerContext") + Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") true } catch { case _: ClassNotFoundException => @@ -2541,6 +2678,7 @@ private[util] object CallerContext extends Logging { * @param from who sets up the caller context (TASK, CLIENT, APPMASTER) * * The parameters below are optional: + * @param upstreamCallerContext caller context the upstream application passes in * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to * @param jobId id of the job this task belongs to @@ -2550,26 +2688,38 @@ private[util] object CallerContext extends Logging { * @param taskAttemptNumber task attempt id */ private[spark] class CallerContext( - from: String, - appId: Option[String] = None, - appAttemptId: Option[String] = None, - jobId: Option[Int] = None, - stageId: Option[Int] = None, - stageAttemptId: Option[Int] = None, - taskId: Option[Long] = None, - taskAttemptNumber: Option[Int] = None) extends Logging { - - val appIdStr = if (appId.isDefined) s"_${appId.get}" else "" - val appAttemptIdStr = if (appAttemptId.isDefined) s"_${appAttemptId.get}" else "" - val jobIdStr = if (jobId.isDefined) s"_JId_${jobId.get}" else "" - val stageIdStr = if (stageId.isDefined) s"_SId_${stageId.get}" else "" - val stageAttemptIdStr = if (stageAttemptId.isDefined) s"_${stageAttemptId.get}" else "" - val taskIdStr = if (taskId.isDefined) s"_TId_${taskId.get}" else "" - val taskAttemptNumberStr = - if (taskAttemptNumber.isDefined) s"_${taskAttemptNumber.get}" else "" - - val context = "SPARK_" + from + appIdStr + appAttemptIdStr + - jobIdStr + stageIdStr + stageAttemptIdStr + taskIdStr + taskAttemptNumberStr + from: String, + upstreamCallerContext: Option[String] = None, + appId: Option[String] = None, + appAttemptId: Option[String] = None, + jobId: Option[Int] = None, + stageId: Option[Int] = None, + stageAttemptId: Option[Int] = None, + taskId: Option[Long] = None, + taskAttemptNumber: Option[Int] = None) extends Logging { + + private val context = prepareContext("SPARK_" + + from + + appId.map("_" + _).getOrElse("") + + appAttemptId.map("_" + _).getOrElse("") + + jobId.map("_JId_" + _).getOrElse("") + + stageId.map("_SId_" + _).getOrElse("") + + stageAttemptId.map("_" + _).getOrElse("") + + taskId.map("_TId_" + _).getOrElse("") + + taskAttemptNumber.map("_" + _).getOrElse("") + + upstreamCallerContext.map("_" + _).getOrElse("")) + + private def prepareContext(context: String): String = { + // The default max size of Hadoop caller context is 128 + lazy val len = SparkHadoopUtil.get.conf.getInt("hadoop.caller.context.max.size", 128) + if (context == null || context.length <= len) { + context + } else { + val finalContext = context.substring(0, len) + logWarning(s"Truncated Spark caller context from $context to $finalContext") + finalContext + } + } /** * Set up the caller context [[context]] by 
invoking Hadoop CallerContext API of @@ -2578,10 +2728,8 @@ private[spark] class CallerContext( def setCurrentContext(): Unit = { if (CallerContext.callerContextSupported) { try { - // scalastyle:off classforname - val callerContext = Class.forName("org.apache.hadoop.ipc.CallerContext") - val builder = Class.forName("org.apache.hadoop.ipc.CallerContext$Builder") - // scalastyle:on classforname + val callerContext = Utils.classForName("org.apache.hadoop.ipc.CallerContext") + val builder = Utils.classForName("org.apache.hadoop.ipc.CallerContext$Builder") val builderInst = builder.getConstructor(classOf[String]).newInstance(context) val hdfsContext = builder.getMethod("build").invoke(builderInst) callerContext.getMethod("setCurrent", callerContext).invoke(null, hdfsContext) diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index 6b74a29aceda9..bcb95b416dd25 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -140,16 +140,16 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k.equals(curKey)) { - val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) - data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] - return newValue - } else if (curKey.eq(null)) { + if (curKey.eq(null)) { val newValue = updateFunc(false, null.asInstanceOf[V]) data(2 * pos) = k data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] incrementSize() return newValue + } else if (k.eq(curKey) || k.equals(curKey)) { + val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue } else { val delta = i pos = (pos + delta) & mask diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 948cc3b099b18..8aafda5e45d52 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -192,12 +192,19 @@ class ExternalAppendOnlyMap[K, V, C]( * It will be called by TaskMemoryManager when there is not enough memory for the task. */ override protected[this] def forceSpill(): Boolean = { - assert(readingIterator != null) - val isSpilled = readingIterator.spill() - if (isSpilled) { - currentMap = null + if (readingIterator != null) { + val isSpilled = readingIterator.spill() + if (isSpilled) { + currentMap = null + } + isSpilled + } else if (currentMap.size > 0) { + spill(currentMap) + currentMap = new SizeTrackingAppendOnlyMap[K, C] + true + } else { + false } - isSpilled } /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala new file mode 100644 index 0000000000000..6e57c3c5bee8c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/MedianHeap.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import scala.collection.mutable.PriorityQueue + +/** + * MedianHeap is designed to be used to quickly track the median of a group of numbers + * that may contain duplicates. Inserting a new number has O(log n) time complexity and + * determining the median has O(1) time complexity. + * The basic idea is to maintain two heaps: a smallerHalf and a largerHalf. The smallerHalf + * stores the smaller half of all numbers while the largerHalf stores the larger half. + * The sizes of two heaps need to be balanced each time when a new number is inserted so + * that their sizes will not be different by more than 1. Therefore each time when + * findMedian() is called we check if two heaps have the same size. If they do, we should + * return the average of the two top values of heaps. Otherwise we return the top of the + * heap which has one more element. + */ +private[spark] class MedianHeap(implicit val ord: Ordering[Double]) { + + /** + * Stores all the numbers less than the current median in a smallerHalf, + * i.e median is the maximum, at the root. + */ + private[this] var smallerHalf = PriorityQueue.empty[Double](ord) + + /** + * Stores all the numbers greater than the current median in a largerHalf, + * i.e median is the minimum, at the root. + */ + private[this] var largerHalf = PriorityQueue.empty[Double](ord.reverse) + + def isEmpty(): Boolean = { + smallerHalf.isEmpty && largerHalf.isEmpty + } + + def size(): Int = { + smallerHalf.size + largerHalf.size + } + + def insert(x: Double): Unit = { + // If both heaps are empty, we arbitrarily insert it into a heap, let's say, the largerHalf. + if (isEmpty) { + largerHalf.enqueue(x) + } else { + // If the number is larger than current median, it should be inserted into largerHalf, + // otherwise smallerHalf. 
+ if (x > median) { + largerHalf.enqueue(x) + } else { + smallerHalf.enqueue(x) + } + } + rebalance() + } + + private[this] def rebalance(): Unit = { + if (largerHalf.size - smallerHalf.size > 1) { + smallerHalf.enqueue(largerHalf.dequeue()) + } + if (smallerHalf.size - largerHalf.size > 1) { + largerHalf.enqueue(smallerHalf.dequeue) + } + } + + def median: Double = { + if (isEmpty) { + throw new NoSuchElementException("MedianHeap is empty.") + } + if (largerHalf.size == smallerHalf.size) { + (largerHalf.head + smallerHalf.head) / 2.0 + } else if (largerHalf.size > smallerHalf.size) { + largerHalf.head + } else { + smallerHalf.head + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 0f6a425e3db9a..60f6f537c1d54 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -48,7 +48,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( require(initialCapacity <= OpenHashSet.MAX_CAPACITY, s"Can't make capacity bigger than ${OpenHashSet.MAX_CAPACITY} elements") - require(initialCapacity >= 1, "Invalid initial capacity") + require(initialCapacity >= 0, "Invalid initial capacity") require(loadFactor < 1.0, "Load factor must be less than 1.0") require(loadFactor > 0.0, "Load factor must be greater than 0.0") @@ -271,8 +271,12 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { - val highBit = Integer.highestOneBit(n) - if (highBit == n) n else highBit << 1 + if (n == 0) { + 1 + } else { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 + } } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index 89b0874e3865a..2f905c8af0f63 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -86,7 +86,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Copy this buffer into a new ByteBuffer. + * Convert this buffer to a ByteBuffer. If this buffer is backed by a single chunk, its underlying + * data will not be copied. Instead, it will be duplicated. If this buffer is backed by multiple + * chunks, the data underlying this buffer will be copied into a new byte buffer. As a result, it + * is suggested to use this method only if the caller does not need to manage the memory + * underlying this buffer. * * @throws UnsupportedOperationException if this buffer's size exceeds the max ByteBuffer size. */ @@ -132,10 +136,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { } /** - * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that - * might cause errors if one attempts to read from the unmapped buffer, but it's better than - * waiting for the GC to find it because that could lead to huge numbers of open files. There's - * unfortunately no standard API to do this. + * Attempt to clean up any ByteBuffer in this ChunkedByteBuffer which is direct or memory-mapped. + * See [[StorageUtils.dispose]] for more information. 
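To make the two-heap invariant in the new MedianHeap concrete, here is a small usage sketch. MedianHeap is `private[spark]`, so this only compiles from Spark-internal code, and the input values are arbitrary illustrative numbers:

```scala
import org.apache.spark.util.collection.MedianHeap

val heap = new MedianHeap()
Seq(3.0, 1.0, 4.0, 1.0, 5.0).foreach(heap.insert)
// Five elements: the larger half holds one more value, so its smallest value is the median.
assert(heap.median == 3.0)

heap.insert(9.0)
// Six elements: both halves hold three values, so the median is the mean of the two tops.
assert(heap.median == 3.5)
```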
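The duplicate-versus-copy behavior documented above for converting a ChunkedByteBuffer to a ByteBuffer can be summarized with a short sketch. This only illustrates the behavior the comment describes; the class is Spark-internal and the chunk sizes here are arbitrary:

```scala
import java.nio.ByteBuffer
import org.apache.spark.util.io.ChunkedByteBuffer

// One chunk: the conversion duplicates the chunk, so no bytes are copied and the
// returned ByteBuffer shares its storage with this ChunkedByteBuffer.
val single = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
val sharedView = single.toByteBuffer

// Several chunks: the contents are copied into one new ByteBuffer.
val multi = new ChunkedByteBuffer(Array(ByteBuffer.allocate(4), ByteBuffer.allocate(4)))
val copiedView = multi.toByteBuffer
```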
*/ def dispose(): Unit = { if (!disposed) { @@ -143,15 +145,16 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { disposed = true } } + } /** * Reads data from a ChunkedByteBuffer. * - * @param dispose if true, [[ChunkedByteBuffer.dispose()]] will be called at the end of the stream + * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream * in order to close any memory-mapped files which back the buffer. */ -private class ChunkedByteBufferInputStream( +private[spark] class ChunkedByteBufferInputStream( var chunkedByteBuffer: ChunkedByteBuffer, dispose: Boolean) extends InputStream { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index 5c4238c0381a1..1f263df57c857 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -18,7 +18,7 @@ package org.apache.spark.util.logging import java.text.SimpleDateFormat -import java.util.Calendar +import java.util.{Calendar, Locale} import org.apache.spark.internal.Logging @@ -59,7 +59,7 @@ private[spark] class TimeBasedRollingPolicy( } @volatile private var nextRolloverTime = calculateNextRolloverTime() - private val formatter = new SimpleDateFormat(rollingFileSuffixPattern) + private val formatter = new SimpleDateFormat(rollingFileSuffixPattern, Locale.US) /** Should rollover if current time has exceeded next rollover time */ def shouldRollover(bytesToBeWritten: Long): Boolean = { @@ -109,7 +109,7 @@ private[spark] class SizeBasedRollingPolicy( } @volatile private var bytesWrittenSinceRollover = 0L - val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS") + val formatter = new SimpleDateFormat("--yyyy-MM-dd--HH-mm-ss--SSSS", Locale.US) /** Should rollover if the next set of bytes is going to exceed the size limit */ def shouldRollover(bytesToBeWritten: Long): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 8c67364ef1a05..ea99a7e5b4847 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -19,7 +19,6 @@ package org.apache.spark.util.random import java.util.Random -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.commons.math3.distribution.PoissonDistribution diff --git a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala index f98932a470165..a7e0075debedb 100644 --- a/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala @@ -56,28 +56,33 @@ private[spark] object SamplingUtils { val rand = new XORShiftRandom(seed) while (input.hasNext) { val item = input.next() + l += 1 + // There are k elements in the reservoir, and the l-th element has been + // consumed. It should be chosen with probability k/l. 
The expression + // below is a random long chosen uniformly from [0,l) val replacementIndex = (rand.nextDouble() * l).toLong if (replacementIndex < k) { reservoir(replacementIndex.toInt) = item } - l += 1 } (reservoir, l) } } /** - * Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of - * the time. + * Returns a sampling rate that guarantees a sample of size greater than or equal to + * sampleSizeLowerBound 99.99% of the time. * * How the sampling rate is determined: + * * Let p = num / total, where num is the sample size and total is the total number of - * datapoints in the RDD. We're trying to compute q > p such that + * datapoints in the RDD. We're trying to compute q {@literal >} p such that * - when sampling with replacement, we're drawing each datapoint with prob_i ~ Pois(q), - * where we want to guarantee Pr[s < num] < 0.0001 for s = sum(prob_i for i from 0 to total), - * i.e. the failure rate of not having a sufficiently large sample < 0.0001. + * where we want to guarantee + * Pr[s {@literal <} num] {@literal <} 0.0001 for s = sum(prob_i for i from 0 to total), + * i.e. the failure rate of not having a sufficiently large sample {@literal <} 0.0001. * Setting q = p + 5 * sqrt(p/total) is sufficient to guarantee 0.9999 success rate for - * num > 12, but we need a slightly larger q (9 empirically determined). + * num {@literal >} 12, but we need a slightly larger q (9 empirically determined). * - when sampling without replacement, we're drawing each datapoint with prob_i * ~ Binomial(total, fraction) and our choice of q guarantees 1-delta, or 0.9999 success * rate, where success rate is defined the same as in sampling with replacement. @@ -108,14 +113,14 @@ private[spark] object SamplingUtils { private[spark] object PoissonBounds { /** - * Returns a lambda such that Pr[X > s] is very small, where X ~ Pois(lambda). + * Returns a lambda such that Pr[X {@literal >} s] is very small, where X ~ Pois(lambda). */ def getLowerBound(s: Double): Double = { math.max(s - numStd(s) * math.sqrt(s), 1e-15) } /** - * Returns a lambda such that Pr[X < s] is very small, where X ~ Pois(lambda). + * Returns a lambda such that Pr[X {@literal <} s] is very small, where X ~ Pois(lambda). * * @param s sample size */ diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala index 67822749112c6..ce46fc8f201be 100644 --- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala @@ -35,13 +35,14 @@ import org.apache.spark.rdd.RDD * high probability. This is achieved by maintaining a waitlist of size O(log(s)), where s is the * desired sample size for each stratum. * - * Like in simple random sampling, we generate a random value for each item from the - * uniform distribution [0.0, 1.0]. All items with values <= min(values of items in the waitlist) - * are accepted into the sample instantly. The threshold for instant accept is designed so that - * s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by maintaining a - * waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size s by adding - * a portion of the waitlist to the set of items that are instantly accepted. The exact threshold - * is computed by sorting the values in the waitlist and picking the value at (s - numAccepted). 
+ * Like in simple random sampling, we generate a random value for each item from the uniform + * distribution [0.0, 1.0]. All items with values less than or equal to min(values of items in the + * waitlist) are accepted into the sample instantly. The threshold for instant accept is designed + * so that s - numAccepted = O(sqrt(s)), where s is again the desired sample size. Thus, by + * maintaining a waitlist size = O(sqrt(s)), we will be able to create a sample of the exact size + * s by adding a portion of the waitlist to the set of items that are instantly accepted. The exact + * threshold is computed by sorting the values in the waitlist and picking the value at + * (s - numAccepted). * * Note that since we use the same seed for the RNG when computing the thresholds and the actual * sample, our computed thresholds are guaranteed to produce the desired sample size. @@ -160,12 +161,20 @@ private[spark] object StratifiedSamplingUtils extends Logging { * * To do so, we compute sampleSize = math.ceil(size * samplingRate) for each stratum and compare * it to the number of items that were accepted instantly and the number of items in the waitlist - * for that stratum. Most of the time, numAccepted <= sampleSize <= (numAccepted + numWaitlisted), + * for that stratum. + * + * Most of the time, + * {{{ + * numAccepted <= sampleSize <= (numAccepted + numWaitlisted) + * }}} * which means we need to sort the elements in the waitlist by their associated values in order - * to find the value T s.t. |{elements in the stratum whose associated values <= T}| = sampleSize. - * Note that all elements in the waitlist have values >= bound for instant accept, so a T value - * in the waitlist range would allow all elements that were instantly accepted on the first pass - * to be included in the sample. + * to find the value T s.t. + * {{{ + * |{elements in the stratum whose associated values <= T}| = sampleSize + * }}}. + * Note that all elements in the waitlist have values greater than or equal to bound for instant + * accept, so a T value in the waitlist range would allow all elements that were instantly + * accepted on the first pass to be included in the sample. 
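The with-replacement sampling rate described in the `computeFractionForSampleSize` comment above boils down to q = p + 9 * sqrt(p / total), where p = num / total. The following is a hypothetical standalone helper illustrating only that formula, not the method Spark actually ships:

```scala
// Hypothetical helper, not Spark's implementation: the with-replacement heuristic
// from the comment above, q = p + 9 * sqrt(p / total), where p = num / total.
def sketchFraction(num: Int, total: Long): Double = {
  val p = num.toDouble / total
  math.min(1.0, p + 9.0 * math.sqrt(p / total))
}

// Example: a sample of at least 100 rows out of 1,000,000:
// p = 1e-4 and sqrt(p / total) = 1e-5, so q ≈ 1e-4 + 9e-5 = 1.9e-4,
// i.e. draw each row with probability ~0.00019 (about 190 expected rows).
```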
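The exact-threshold step described above (sort the waitlist, then pick the value at position s - numAccepted) can be illustrated with made-up numbers; none of the values or counts below come from Spark, and the indexing detail is only this sketch's interpretation of the comment:

```scala
// Made-up numbers: a stratum wants s = 100 items, 96 were instantly accepted,
// and 10 candidates sit on the waitlist with their associated uniform values.
val s = 100
val numAccepted = 96
val waitlist = Seq(0.912, 0.905, 0.931, 0.908, 0.927, 0.919, 0.902, 0.934, 0.916, 0.923)

// Take the (s - numAccepted)-th smallest waitlist value (0-based index, hence the - 1).
val threshold = waitlist.sorted.apply(s - numAccepted - 1)  // == 0.912

// Exactly 4 waitlisted items have values <= threshold, so accepting every item whose
// value is <= threshold yields the desired sample size of 96 + 4 = 100.
```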
*/ def computeThresholdByKey[K](finalResult: Map[K, AcceptanceResult], fractions: Map[K, Double]): Map[K, Double] = { diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala index 1be31e88ab68e..51feccfb8342a 100644 --- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala +++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala @@ -55,14 +55,16 @@ class TaskCompletionListenerException( extends RuntimeException { override def getMessage: String = { - if (errorMessages.size == 1) { - errorMessages.head - } else { - errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") - } + - previousError.map { e => + val listenerErrorMessage = + if (errorMessages.size == 1) { + errorMessages.head + } else { + errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n") + } + val previousErrorMessage = previousError.map { e => "\n\nPrevious exception in task: " + e.getMessage + "\n" + e.getStackTrace.mkString("\t", "\n\t", "") }.getOrElse("") + listenerErrorMessage + previousErrorMessage } } diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java index 7fe452a48d89b..a6589d2898144 100644 --- a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java +++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java @@ -20,14 +20,11 @@ import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; -import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; import org.apache.spark.rdd.JdbcRDD; import org.junit.After; import org.junit.Assert; @@ -89,30 +86,13 @@ public void tearDown() throws SQLException { public void testJavaJdbcRDD() throws Exception { JavaRDD rdd = JdbcRDD.create( sc, - new JdbcRDD.ConnectionFactory() { - @Override - public Connection getConnection() throws SQLException { - return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"); - } - }, + () -> DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"), "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", 1, 100, 1, - new Function() { - @Override - public Integer call(ResultSet r) throws Exception { - return r.getInt(1); - } - } + r -> r.getInt(1) ).cache(); Assert.assertEquals(100, rdd.count()); - Assert.assertEquals( - Integer.valueOf(10100), - rdd.reduce(new Function2() { - @Override - public Integer call(Integer i1, Integer i2) { - return i1 + i2; - } - })); + Assert.assertEquals(Integer.valueOf(10100), rdd.reduce((i1, i2) -> i1 + i2)); } } diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 682d98867b456..0c77123740852 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -27,8 +27,10 @@ import org.slf4j.LoggerFactory; import org.slf4j.bridge.SLF4JBridgeHandler; import static org.junit.Assert.*; +import static org.junit.Assume.*; import org.apache.spark.internal.config.package$; +import org.apache.spark.util.Utils; /** * These tests require the Spark assembly to be built before they can be run. 
@@ -155,6 +157,10 @@ public void testRedirectToLog() throws Exception { @Test public void testChildProcLauncher() throws Exception { + // This test is failed on Windows due to the failure of initiating executors + // by the path length limitation. See SPARK-18718. + assumeTrue(!Utils.isWindows()); + SparkSubmitOptionParser opts = new SparkSubmitOptionParser(); Map env = new HashMap<>(); env.put("SPARK_PRINT_LAUNCH_COMMAND", "1"); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index ad755529dec64..f53bc0b02bbfa 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -109,6 +109,41 @@ public void cooperativeSpilling() { Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory()); } + @Test + public void cooperativeSpilling2() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(100); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0); + + TestMemoryConsumer c1 = new TestMemoryConsumer(manager); + TestMemoryConsumer c2 = new TestMemoryConsumer(manager); + TestMemoryConsumer c3 = new TestMemoryConsumer(manager); + + c1.use(20); + Assert.assertEquals(20, c1.getUsed()); + c2.use(80); + Assert.assertEquals(80, c2.getUsed()); + c3.use(80); + Assert.assertEquals(20, c1.getUsed()); // c1: not spilled + Assert.assertEquals(0, c2.getUsed()); // c2: spilled as it has required size of memory + Assert.assertEquals(80, c3.getUsed()); + + c2.use(80); + Assert.assertEquals(20, c1.getUsed()); // c1: not spilled + Assert.assertEquals(0, c3.getUsed()); // c3: spilled as it has required size of memory + Assert.assertEquals(80, c2.getUsed()); + + c3.use(10); + Assert.assertEquals(0, c1.getUsed()); // c1: spilled as it has required size of memory + Assert.assertEquals(80, c2.getUsed()); // c2: not spilled as spilling c1 already satisfies c3 + Assert.assertEquals(10, c3.getUsed()); + + c1.free(0); + c2.free(80); + c3.free(10); + Assert.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + @Test public void shouldNotForceSpillingInDifferentModes() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a96cd82382e2c..24a55df84a240 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -26,18 +26,14 @@ import scala.Tuple2; import scala.Tuple2$; import scala.collection.Iterator; -import scala.runtime.AbstractFunction1; import com.google.common.collect.HashMultiset; import com.google.common.collect.Iterators; -import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.HashPartitioner; import org.apache.spark.ShuffleDependency; @@ -53,6 +49,7 @@ import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.security.CryptoStreamUtils; import 
org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.*; @@ -77,7 +74,6 @@ public class UnsafeShuffleWriterSuite { final LinkedList spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); - final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -86,17 +82,6 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency shuffleDep; - private final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); - } else { - return stream; - } - } - } - @After public void tearDown() { Utils.deleteRecursively(tempDir); @@ -121,53 +106,46 @@ public void setUp() throws IOException { memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + // Some tests will override this manager because they change the configuration. This is a + // default for tests that don't need a specific one. + SerializerManager manager = new SerializerManager(serializer, conf); + when(blockManager.serializerManager()).thenReturn(manager); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); - return new DiskBlockObjectWriter( (File) args[1], + blockManager.serializerManager(), (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); - doAnswer(new Answer() { - @Override - public Void answer(InvocationOnMock invocationOnMock) throws Throwable { - partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; - File tmp = (File) invocationOnMock.getArguments()[3]; - mergedOutputFile.delete(); - tmp.renameTo(mergedOutputFile); - return null; - } + doAnswer(invocationOnMock -> { + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; + File tmp = (File) invocationOnMock.getArguments()[3]; + mergedOutputFile.delete(); + tmp.renameTo(mergedOutputFile); + return null; }).when(shuffleBlockResolver) .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); - when(diskBlockManager.createTempShuffleBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer( - InvocationOnMock invocationOnMock) throws Throwable { - TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } - }); + when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { + TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); + File file 
= File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); + }); when(taskContext.taskMetrics()).thenReturn(taskMetrics); when(shuffleDep.serializer()).thenReturn(serializer); @@ -201,9 +179,10 @@ private List> readRecordsFromFile() throws IOException { for (int i = 0; i < NUM_PARTITITONS; i++) { final long partitionSize = partitionSizesInMergedFile[i]; if (partitionSize > 0) { - InputStream in = new FileInputStream(mergedOutputFile); - ByteStreams.skipFully(in, startOffset); - in = new LimitedInputStream(in, partitionSize); + FileInputStream fin = new FileInputStream(mergedOutputFile); + fin.getChannel().position(startOffset); + InputStream in = new LimitedInputStream(fin, partitionSize); + in = blockManager.serializerManager().wrapForEncryption(in); if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } @@ -251,7 +230,7 @@ class BadRecords extends scala.collection.AbstractIterator writer = createWriter(true); - writer.write(Iterators.>emptyIterator()); + writer.write(Iterators.emptyIterator()); final Option mapStatus = writer.stop(true); assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); @@ -267,7 +246,7 @@ public void writeWithoutSpilling() throws Exception { // In this example, each partition should have exactly one record: final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < NUM_PARTITITONS; i++) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } final UnsafeShuffleWriter writer = createWriter(true); writer.write(dataToWrite.iterator()); @@ -294,18 +273,36 @@ public void writeWithoutSpilling() throws Exception { } private void testMergingSpills( - boolean transferToEnabled, - String compressionCodecName) throws IOException { + final boolean transferToEnabled, + String compressionCodecName, + boolean encrypt) throws Exception { if (compressionCodecName != null) { conf.set("spark.shuffle.compress", "true"); conf.set("spark.io.compression.codec", compressionCodecName); } else { conf.set("spark.shuffle.compress", "false"); } + conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt); + + SerializerManager manager; + if (encrypt) { + manager = new SerializerManager(serializer, conf, + Option.apply(CryptoStreamUtils.createKey(conf))); + } else { + manager = new SerializerManager(serializer, conf); + } + + when(blockManager.serializerManager()).thenReturn(manager); + testMergingSpills(transferToEnabled, encrypt); + } + + private void testMergingSpills( + boolean transferToEnabled, + boolean encrypted) throws IOException { final UnsafeShuffleWriter writer = createWriter(transferToEnabled); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } writer.insertRecordIntoSorter(dataToWrite.get(0)); writer.insertRecordIntoSorter(dataToWrite.get(1)); @@ -324,6 +321,7 @@ private void testMergingSpills( for (long size: partitionSizesInMergedFile) { sumOfPartitionSizes += size; } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); @@ -338,42 +336,72 @@ private void testMergingSpills( @Test public void mergeSpillsWithTransferToAndLZF() throws Exception { - testMergingSpills(true, LZFCompressionCodec.class.getName()); + 
testMergingSpills(true, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndLZF() throws Exception { - testMergingSpills(false, LZFCompressionCodec.class.getName()); + testMergingSpills(false, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndLZ4() throws Exception { - testMergingSpills(true, LZ4CompressionCodec.class.getName()); + testMergingSpills(true, LZ4CompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndLZ4() throws Exception { - testMergingSpills(false, LZ4CompressionCodec.class.getName()); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndSnappy() throws Exception { - testMergingSpills(true, SnappyCompressionCodec.class.getName()); + testMergingSpills(true, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndSnappy() throws Exception { - testMergingSpills(false, SnappyCompressionCodec.class.getName()); + testMergingSpills(false, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndNoCompression() throws Exception { - testMergingSpills(true, null); + testMergingSpills(true, null, false); } @Test public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { - testMergingSpills(false, null); + testMergingSpills(false, null, false); + } + + @Test + public void mergeSpillsWithCompressionAndEncryption() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. + testMergingSpills(true, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { + conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false"); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithEncryptionAndNoCompression() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. 
+ testMergingSpills(true, null, true); + } + + @Test + public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception { + testMergingSpills(false, null, true); } @Test @@ -383,7 +411,7 @@ public void writeEnoughDataToTriggerSpill() throws Exception { final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 10]; for (int i = 0; i < 10 + 1; i++) { - dataToWrite.add(new Tuple2(i, bigByteArray)); + dataToWrite.add(new Tuple2<>(i, bigByteArray)); } writer.write(dataToWrite.iterator()); assertEquals(2, spillFilesCreated.size()); @@ -417,7 +445,7 @@ private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exc final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) { - dataToWrite.add(new Tuple2(i, i)); + dataToWrite.add(new Tuple2<>(i, i)); } writer.write(dataToWrite.iterator()); writer.stop(true); @@ -437,7 +465,7 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception final ArrayList> dataToWrite = new ArrayList<>(); final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)]; new Random(42).nextBytes(bytes); - dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(bytes))); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(bytes))); writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( @@ -450,15 +478,15 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - dataToWrite.add(new Tuple2(1, ByteBuffer.wrap(new byte[1]))); + dataToWrite.add(new Tuple2<>(1, ByteBuffer.wrap(new byte[1]))); // We should be able to write a record that's right _at_ the max record size final byte[] atMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes() - 4]; new Random(42).nextBytes(atMaxRecordSize); - dataToWrite.add(new Tuple2(2, ByteBuffer.wrap(atMaxRecordSize))); + dataToWrite.add(new Tuple2<>(2, ByteBuffer.wrap(atMaxRecordSize))); // Inserting a record that's larger than the max record size final byte[] exceedsMaxRecordSize = new byte[(int) taskMemoryManager.pageSizeBytes()]; new Random(42).nextBytes(exceedsMaxRecordSize); - dataToWrite.add(new Tuple2(3, ByteBuffer.wrap(exceedsMaxRecordSize))); + dataToWrite.add(new Tuple2<>(3, ByteBuffer.wrap(exceedsMaxRecordSize))); writer.write(dataToWrite.iterator()); writer.stop(true); assertEquals( @@ -470,10 +498,10 @@ public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception { @Test public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException { final UnsafeShuffleWriter writer = createWriter(false); - writer.insertRecordIntoSorter(new Tuple2(1, 1)); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(new Tuple2<>(1, 1)); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); writer.forceSorterToSpill(); - writer.insertRecordIntoSorter(new Tuple2(2, 2)); + writer.insertRecordIntoSorter(new Tuple2<>(2, 2)); writer.stop(false); assertSpillFilesWereCleanedUp(); } @@ -531,4 +559,5 @@ public void testPeakMemoryUsed() throws Exception { writer.stop(false); } } + } diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java 
b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 33709b454c4c9..03cec8ed81b72 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,13 +19,10 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; -import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Assert; @@ -33,8 +30,6 @@ import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.SparkConf; import org.apache.spark.executor.ShuffleWriteMetrics; @@ -75,13 +70,6 @@ public abstract class AbstractBytesToBytesMapSuite { @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - private static final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - @Before public void setup() { memoryManager = @@ -97,38 +85,30 @@ public void setup() { spillFilesCreated.clear(); MockitoAnnotations.initMocks(this); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) - throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } + when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); }); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); } @After diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a9cf8ff520ed4..771d39016c188 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,22 +19,17 @@ import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; import java.util.UUID; -import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; 
import org.mockito.MockitoAnnotations; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; @@ -57,13 +52,15 @@ public class UnsafeExternalSorterSuite { + private final SparkConf conf = new SparkConf(); + final LinkedList spillFilesCreated = new LinkedList<>(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final SerializerManager serializerManager = new SerializerManager( - new JavaSerializer(new SparkConf()), - new SparkConf().set("spark.shuffle.spill.compress", "false")); + new JavaSerializer(conf), + conf.clone().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = PrefixComparators.LONG; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -86,14 +83,7 @@ public int compare( protected boolean shouldUseRadixSort() { return false; } - private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); - - private static final class WrapStream extends AbstractFunction1 { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } + private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m"); @Before public void setUp() { @@ -103,38 +93,30 @@ public void setUp() { taskContext = mock(TaskContext.class); when(taskContext.taskMetrics()).thenReturn(new TaskMetrics()); when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); - when(diskBlockManager.createTempLocalBlock()).thenAnswer( - new Answer>() { - @Override - public Tuple2 answer(InvocationOnMock invocationOnMock) - throws Throwable { - TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); - File file = File.createTempFile("spillFile", ".spill", tempDir); - spillFilesCreated.add(file); - return Tuple2$.MODULE$.apply(blockId, file); - } + when(diskBlockManager.createTempLocalBlock()).thenAnswer(invocationOnMock -> { + TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID()); + File file = File.createTempFile("spillFile", ".spill", tempDir); + spillFilesCreated.add(file); + return Tuple2$.MODULE$.apply(blockId, file); }); when(blockManager.getDiskWriter( any(BlockId.class), any(File.class), any(SerializerInstance.class), anyInt(), - any(ShuffleWriteMetrics.class))).thenAnswer(new Answer() { - @Override - public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { + any(ShuffleWriteMetrics.class))).thenAnswer(invocationOnMock -> { Object[] args = invocationOnMock.getArguments(); return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] ); - } - }); + }); } @After diff --git a/external/java8-tests/src/test/java/test/org/apache/spark/java8/Java8RDDAPISuite.java b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java similarity index 98% rename from external/java8-tests/src/test/java/test/org/apache/spark/java8/Java8RDDAPISuite.java rename to core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java index fa3a66e73ced6..1d2b05ebc2503 
100644 --- a/external/java8-tests/src/test/java/test/org/apache/spark/java8/Java8RDDAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/Java8RDDAPISuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package test.org.apache.spark.java8; +package test.org.apache.spark; import java.io.File; import java.io.Serializable; @@ -64,12 +64,7 @@ public void tearDown() { public void foreachWithAnonymousClass() { foreachCalls = 0; JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - foreachCalls++; - } - }); + rdd.foreach(s -> foreachCalls++); Assert.assertEquals(2, foreachCalls); } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java similarity index 74% rename from core/src/test/java/org/apache/spark/JavaAPISuite.java rename to core/src/test/java/test/org/apache/spark/JavaAPISuite.java index 533025ba83e72..01b5fb7b46684 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java @@ -15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark; +package test.org.apache.spark; import java.io.*; import java.nio.channels.FileChannel; import java.nio.ByteBuffer; -import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -32,9 +31,14 @@ import java.util.LinkedList; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.*; +import org.apache.spark.Accumulator; +import org.apache.spark.AccumulatorParam; +import org.apache.spark.Partitioner; +import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; +import org.apache.spark.TaskContext$; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -46,6 +50,7 @@ import com.google.common.collect.Lists; import com.google.common.base.Throwables; import com.google.common.io.Files; +import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.io.compress.DefaultCodec; @@ -202,7 +207,7 @@ public void sortByKey() { assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // Custom comparator - sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); + sortedRDD = rdd.sortByKey(Collections.reverseOrder(), false); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); assertEquals(new Tuple2<>(0, 4), sortedPairs.get(1)); @@ -260,13 +265,7 @@ public void sortBy() { JavaRDD> rdd = sc.parallelize(pairs); // compare on first value - JavaRDD> sortedRDD = - rdd.sortBy(new Function, Integer>() { - @Override - public Integer call(Tuple2 t) { - return t._1(); - } - }, true, 2); + JavaRDD> sortedRDD = rdd.sortBy(Tuple2::_1, true, 2); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); List> sortedPairs = sortedRDD.collect(); @@ -274,12 +273,7 @@ public Integer call(Tuple2 t) { assertEquals(new Tuple2<>(3, 2), sortedPairs.get(2)); // compare on second value - sortedRDD = rdd.sortBy(new Function, Integer>() { - @Override - public Integer call(Tuple2 t) { - return t._2(); - } - }, true, 2); + sortedRDD = rdd.sortBy(Tuple2::_2, true, 2); assertEquals(new Tuple2<>(-1, 1), sortedRDD.first()); sortedPairs = sortedRDD.collect(); assertEquals(new Tuple2<>(3, 2), sortedPairs.get(1)); @@ -288,28 +282,20 @@ public Integer call(Tuple2 t) { @Test public void foreach() { - final 
LongAccumulator accum = sc.sc().longAccumulator(); + LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreach(new VoidFunction() { - @Override - public void call(String s) { - accum.add(1); - } - }); + rdd.foreach(s -> accum.add(1)); assertEquals(2, accum.value().intValue()); } @Test public void foreachPartition() { - final LongAccumulator accum = sc.sc().longAccumulator(); + LongAccumulator accum = sc.sc().longAccumulator(); JavaRDD rdd = sc.parallelize(Arrays.asList("Hello", "World")); - rdd.foreachPartition(new VoidFunction>() { - @Override - public void call(Iterator iter) { - while (iter.hasNext()) { - iter.next(); - accum.add(1); - } + rdd.foreachPartition(iter -> { + while (iter.hasNext()) { + iter.next(); + accum.add(1); } }); assertEquals(2, accum.value().intValue()); @@ -355,12 +341,7 @@ public void lookup() { @Test public void groupBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function isOdd = new Function() { - @Override - public Boolean call(Integer x) { - return x % 2 == 0; - } - }; + Function isOdd = x -> x % 2 == 0; JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); assertEquals(2, oddsAndEvens.count()); assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens @@ -377,12 +358,7 @@ public void groupByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Function, Boolean> areOdd = - new Function, Boolean>() { - @Override - public Boolean call(Tuple2 x) { - return (x._1() % 2 == 0) && (x._2() % 2 == 0); - } - }; + x -> (x._1() % 2 == 0) && (x._2() % 2 == 0); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD>> oddsAndEvens = pairRDD.groupBy(areOdd); assertEquals(2, oddsAndEvens.count()); @@ -400,13 +376,7 @@ public Boolean call(Tuple2 x) { public void keyByOnPairRDD() { // Regression test for SPARK-4459 JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function, String> sumToString = - new Function, String>() { - @Override - public String call(Tuple2 x) { - return String.valueOf(x._1() + x._2()); - } - }; + Function, String> sumToString = x -> String.valueOf(x._1() + x._2()); JavaPairRDD pairRDD = rdd.zip(rdd); JavaPairRDD> keyed = pairRDD.keyBy(sumToString); assertEquals(7, keyed.count()); @@ -510,25 +480,14 @@ public void leftOuterJoin() { rdd1.leftOuterJoin(rdd2).collect(); assertEquals(5, joined.size()); Tuple2>> firstUnmatched = - rdd1.leftOuterJoin(rdd2).filter( - new Function>>, Boolean>() { - @Override - public Boolean call(Tuple2>> tup) { - return !tup._2()._2().isPresent(); - } - }).first(); + rdd1.leftOuterJoin(rdd2).filter(tup -> !tup._2()._2().isPresent()).first(); assertEquals(3, firstUnmatched._1().intValue()); } @Test public void foldReduce() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; + Function2 add = (a, b) -> a + b; int sum = rdd.fold(0, add); assertEquals(33, sum); @@ -540,12 +499,7 @@ public Integer call(Integer a, Integer b) { @Test public void treeReduce() { JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; + Function2 add = (a, b) -> a + b; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeReduce(add, depth); assertEquals(-5, sum); @@ -555,12 +509,7 @@ public 
Integer call(Integer a, Integer b) { @Test public void treeAggregate() { JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); - Function2 add = new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }; + Function2 add = (a, b) -> a + b; for (int depth = 1; depth <= 10; depth++) { int sum = rdd.treeAggregate(0, add, add, depth); assertEquals(-5, sum); @@ -578,21 +527,15 @@ public void aggregateByKey() { new Tuple2<>(5, 1), new Tuple2<>(5, 3)), 2); - Map> sets = pairs.aggregateByKey(new HashSet(), - new Function2, Integer, Set>() { - @Override - public Set call(Set a, Integer b) { - a.add(b); - return a; - } - }, - new Function2, Set, Set>() { - @Override - public Set call(Set a, Set b) { - a.addAll(b); - return a; - } - }).collectAsMap(); + Map> sets = pairs.aggregateByKey(new HashSet(), + (a, b) -> { + a.add(b); + return a; + }, + (a, b) -> { + a.addAll(b); + return a; + }).collectAsMap(); assertEquals(3, sets.size()); assertEquals(new HashSet<>(Arrays.asList(1)), sets.get(1)); assertEquals(new HashSet<>(Arrays.asList(2)), sets.get(3)); @@ -610,13 +553,7 @@ public void foldByKey() { new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD sums = rdd.foldByKey(0, - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + JavaPairRDD sums = rdd.foldByKey(0, (a, b) -> a + b); assertEquals(1, sums.lookup(1).get(0).intValue()); assertEquals(2, sums.lookup(2).get(0).intValue()); assertEquals(3, sums.lookup(3).get(0).intValue()); @@ -633,13 +570,7 @@ public void reduceByKey() { new Tuple2<>(3, 1) ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - JavaPairRDD counts = rdd.reduceByKey( - new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + JavaPairRDD counts = rdd.reduceByKey((a, b) -> a + b); assertEquals(1, counts.lookup(1).get(0).intValue()); assertEquals(2, counts.lookup(2).get(0).intValue()); assertEquals(3, counts.lookup(3).get(0).intValue()); @@ -649,12 +580,7 @@ public Integer call(Integer a, Integer b) { assertEquals(2, localCounts.get(2).intValue()); assertEquals(3, localCounts.get(3).intValue()); - localCounts = rdd.reduceByKeyLocally(new Function2() { - @Override - public Integer call(Integer a, Integer b) { - return a + b; - } - }); + localCounts = rdd.reduceByKeyLocally((a, b) -> a + b); assertEquals(1, localCounts.get(1).intValue()); assertEquals(2, localCounts.get(2).intValue()); assertEquals(3, localCounts.get(3).intValue()); @@ -686,20 +612,8 @@ public void isEmpty() { assertTrue(sc.emptyRDD().isEmpty()); assertTrue(sc.parallelize(new ArrayList()).isEmpty()); assertFalse(sc.parallelize(Arrays.asList(1)).isEmpty()); - assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter( - new Function() { - @Override - public Boolean call(Integer i) { - return i < 0; - } - }).isEmpty()); - assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter( - new Function() { - @Override - public Boolean call(Integer i) { - return i > 1; - } - }).isEmpty()); + assertTrue(sc.parallelize(Arrays.asList(1, 2, 3), 3).filter(i -> i < 0).isEmpty()); + assertFalse(sc.parallelize(Arrays.asList(1, 2, 3)).filter(i -> i > 1).isEmpty()); } @Test @@ -715,12 +629,7 @@ public void javaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); JavaDoubleRDD distinct = rdd.distinct(); assertEquals(5, distinct.count()); - JavaDoubleRDD filter = rdd.filter(new Function() { - @Override - 
public Boolean call(Double x) { - return x > 2.0; - } - }); + JavaDoubleRDD filter = rdd.filter(x -> x > 2.0); assertEquals(3, filter.count()); JavaDoubleRDD union = rdd.union(rdd); assertEquals(12, union.count()); @@ -757,8 +666,8 @@ public void javaDoubleRDDHistoGram() { assertArrayEquals(expected_counts, histogram); // SPARK-5744 assertArrayEquals( - new long[] {0}, - sc.parallelizeDoubles(new ArrayList(0), 1).histogram(new double[]{0.0, 1.0})); + new long[] {0}, + sc.parallelizeDoubles(new ArrayList<>(0), 1).histogram(new double[]{0.0, 1.0})); } private static class DoubleComparator implements Comparator, Serializable { @@ -827,12 +736,7 @@ public void reduce() { @Test public void reduceOnJavaDoubleRDD() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); - double sum = rdd.reduce(new Function2() { - @Override - public Double call(Double v1, Double v2) { - return v1 + v2; - } - }); + double sum = rdd.reduce((v1, v2) -> v1 + v2); assertEquals(10.0, sum, 0.001); } @@ -853,27 +757,11 @@ public void aggregate() { @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { - @Override - public double call(Integer x) { - return x.doubleValue(); - } - }).cache(); + JavaDoubleRDD doubles = rdd.mapToDouble(Integer::doubleValue).cache(); doubles.collect(); - JavaPairRDD pairs = rdd.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2<>(x, x); - } - }).cache(); + JavaPairRDD pairs = rdd.mapToPair(x -> new Tuple2<>(x, x)).cache(); pairs.collect(); - JavaRDD strings = rdd.map(new Function() { - @Override - public String call(Integer x) { - return x.toString(); - } - }).cache(); + JavaRDD strings = rdd.map(Object::toString).cache(); strings.collect(); } @@ -881,39 +769,27 @@ public String call(Integer x) { public void flatMap() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", "The quick brown fox jumps over the lazy dog.")); - JavaRDD words = rdd.flatMap(new FlatMapFunction() { - @Override - public Iterator call(String x) { - return Arrays.asList(x.split(" ")).iterator(); - } - }); + JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" ")).iterator()); assertEquals("Hello", words.first()); assertEquals(11, words.count()); - JavaPairRDD pairsRDD = rdd.flatMapToPair( - new PairFlatMapFunction() { - @Override - public Iterator> call(String s) { - List> pairs = new LinkedList<>(); - for (String word : s.split(" ")) { - pairs.add(new Tuple2<>(word, word)); - } - return pairs.iterator(); + JavaPairRDD pairsRDD = rdd.flatMapToPair(s -> { + List> pairs = new LinkedList<>(); + for (String word : s.split(" ")) { + pairs.add(new Tuple2<>(word, word)); } + return pairs.iterator(); } ); assertEquals(new Tuple2<>("Hello", "Hello"), pairsRDD.first()); assertEquals(11, pairsRDD.count()); - JavaDoubleRDD doubles = rdd.flatMapToDouble(new DoubleFlatMapFunction() { - @Override - public Iterator call(String s) { - List lengths = new LinkedList<>(); - for (String word : s.split(" ")) { - lengths.add((double) word.length()); - } - return lengths.iterator(); + JavaDoubleRDD doubles = rdd.flatMapToDouble(s -> { + List lengths = new LinkedList<>(); + for (String word : s.split(" ")) { + lengths.add((double) word.length()); } + return lengths.iterator(); }); assertEquals(5.0, doubles.first(), 0.01); assertEquals(11, pairsRDD.count()); @@ -931,37 +807,23 @@ public void mapsFromPairsToPairs() { // Regression test for SPARK-668: JavaPairRDD swapped 
= pairRDD.flatMapToPair( - new PairFlatMapFunction, String, Integer>() { - @Override - public Iterator> call(Tuple2 item) { - return Collections.singletonList(item.swap()).iterator(); - } - }); + item -> Collections.singletonList(item.swap()).iterator()); swapped.collect(); // There was never a bug here, but it's worth testing: - pairRDD.mapToPair(new PairFunction, String, Integer>() { - @Override - public Tuple2 call(Tuple2 item) { - return item.swap(); - } - }).collect(); + pairRDD.mapToPair(Tuple2::swap).collect(); } @Test public void mapPartitions() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitions( - new FlatMapFunction, Integer>() { - @Override - public Iterator call(Iterator iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum).iterator(); + JavaRDD partitionSums = rdd.mapPartitions(iter -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); } - }); + return Collections.singletonList(sum).iterator(); + }); assertEquals("[3, 7]", partitionSums.collect().toString()); } @@ -969,17 +831,13 @@ public Iterator call(Iterator iter) { @Test public void mapPartitionsWithIndex() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2); - JavaRDD partitionSums = rdd.mapPartitionsWithIndex( - new Function2, Iterator>() { - @Override - public Iterator call(Integer index, Iterator iter) { - int sum = 0; - while (iter.hasNext()) { - sum += iter.next(); - } - return Collections.singletonList(sum).iterator(); + JavaRDD partitionSums = rdd.mapPartitionsWithIndex((index, iter) -> { + int sum = 0; + while (iter.hasNext()) { + sum += iter.next(); } - }, false); + return Collections.singletonList(sum).iterator(); + }, false); assertEquals("[3, 7]", partitionSums.collect().toString()); } @@ -987,11 +845,13 @@ public Iterator call(Integer index, Iterator iter) { public void getNumPartitions(){ JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); JavaDoubleRDD rdd2 = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0), 2); - JavaPairRDD rdd3 = sc.parallelizePairs(Arrays.asList( - new Tuple2<>("a", 1), - new Tuple2<>("aa", 2), - new Tuple2<>("aaa", 3) - ), 2); + JavaPairRDD rdd3 = sc.parallelizePairs( + Arrays.asList( + new Tuple2<>("a", 1), + new Tuple2<>("aa", 2), + new Tuple2<>("aaa", 3) + ), + 2); assertEquals(3, rdd1.getNumPartitions()); assertEquals(2, rdd2.getNumPartitions()); assertEquals(2, rdd3.getNumPartitions()); @@ -1075,18 +935,23 @@ public void wholeTextFiles() throws Exception { byte[] content2 = "spark is also easy to use.\n".getBytes(StandardCharsets.UTF_8); String tempDirName = tempDir.getAbsolutePath(); - Files.write(content1, new File(tempDirName + "/part-00000")); - Files.write(content2, new File(tempDirName + "/part-00001")); + String path1 = new Path(tempDirName, "part-00000").toUri().getPath(); + String path2 = new Path(tempDirName, "part-00001").toUri().getPath(); + + Files.write(content1, new File(path1)); + Files.write(content2, new File(path2)); Map container = new HashMap<>(); - container.put(tempDirName+"/part-00000", new Text(content1).toString()); - container.put(tempDirName+"/part-00001", new Text(content2).toString()); + container.put(path1, new Text(content1).toString()); + container.put(path2, new Text(content2).toString()); JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); List> result = readRDD.collect(); for (Tuple2 res : result) { - assertEquals(res._2(), container.get(new URI(res._1()).getPath())); + // 
Note that the paths from `wholeTextFiles` are in URI format on Windows, + // for example, file:/C:/a/b/c. + assertEquals(res._2(), container.get(new Path(res._1()).toUri().getPath())); } } @@ -1113,21 +978,12 @@ public void sequenceFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); // Try reading the output back as an object file JavaPairRDD readRDD = sc.sequenceFile(outputDir, IntWritable.class, - Text.class).mapToPair(new PairFunction, Integer, String>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(pair._1().get(), pair._2().toString()); - } - }); + Text.class).mapToPair(pair -> new Tuple2<>(pair._1().get(), pair._2().toString())); assertEquals(pairs, readRDD.collect()); } @@ -1168,12 +1024,7 @@ public void binaryFilesCaching() throws Exception { channel1.close(); JavaPairRDD readRDD = sc.binaryFiles(tempDirName).cache(); - readRDD.foreach(new VoidFunction>() { - @Override - public void call(Tuple2 pair) { - pair._2().toArray(); // force the file to read - } - }); + readRDD.foreach(pair -> pair._2().toArray()); // force the file to read List> result = readRDD.collect(); for (Tuple2 res : result) { @@ -1218,23 +1069,13 @@ public void writeWithNewAPIHadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsNewAPIHadoopFile( - outputDir, IntWritable.class, Text.class, + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsNewAPIHadoopFile(outputDir, IntWritable.class, Text.class, org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class); JavaPairRDD output = - sc.sequenceFile(outputDir, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + sc.sequenceFile(outputDir, IntWritable.class, Text.class); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @SuppressWarnings("unchecked") @@ -1248,22 +1089,13 @@ public void readWithNewAPIHadoopFile() throws IOException { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.newAPIHadoopFile(outputDir, - org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, Job.getInstance().getConfiguration()); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return 
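The comment added in this hunk is the key to the `wholeTextFiles` change: on Windows the RDD keys come back as URIs (e.g. `file:/C:/a/b/c`), so both the expected map and the lookup are routed through Hadoop's `Path`. A small sketch of that normalization, using a hard-coded illustrative key rather than a real file:

```java
import org.apache.hadoop.fs.Path;

public class WholeTextFilesKeySketch {
  public static void main(String[] args) {
    // Illustrative key only; on Windows, wholeTextFiles reports keys in this URI form.
    String key = "file:/C:/a/b/c";
    // Routing the key through Path.toUri().getPath() strips the scheme and yields the
    // same string ("/C:/a/b/c") used when the expected map was built, so the container
    // lookup behaves identically on every platform.
    String normalized = new Path(key).toUri().getPath();
    System.out.println(normalized);
  }
}
```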
x.toString(); - } - }).collect().toString()); + org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @Test @@ -1304,21 +1136,12 @@ public void hadoopFile() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + SequenceFileInputFormat.class, IntWritable.class, Text.class); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @SuppressWarnings("unchecked") @@ -1332,34 +1155,20 @@ public void hadoopFileCompressed() { ); JavaPairRDD rdd = sc.parallelizePairs(pairs); - rdd.mapToPair(new PairFunction, IntWritable, Text>() { - @Override - public Tuple2 call(Tuple2 pair) { - return new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2())); - } - }).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class, - DefaultCodec.class); + rdd.mapToPair(pair -> new Tuple2<>(new IntWritable(pair._1()), new Text(pair._2()))) + .saveAsHadoopFile(outputDir, IntWritable.class, Text.class, + SequenceFileOutputFormat.class, DefaultCodec.class); JavaPairRDD output = sc.hadoopFile(outputDir, - SequenceFileInputFormat.class, IntWritable.class, Text.class); + SequenceFileInputFormat.class, IntWritable.class, Text.class); - assertEquals(pairs.toString(), output.map(new Function, String>() { - @Override - public String call(Tuple2 x) { - return x.toString(); - } - }).collect().toString()); + assertEquals(pairs.toString(), output.map(Tuple2::toString).collect().toString()); } @Test public void zip() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - JavaDoubleRDD doubles = rdd.mapToDouble(new DoubleFunction() { - @Override - public double call(Integer x) { - return x.doubleValue(); - } - }); + JavaDoubleRDD doubles = rdd.mapToDouble(Integer::doubleValue); JavaPairRDD zipped = rdd.zip(doubles); zipped.count(); } @@ -1369,12 +1178,7 @@ public void zipPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2); JavaRDD rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2); FlatMapFunction2, Iterator, Integer> sizesFn = - new FlatMapFunction2, Iterator, Integer>() { - @Override - public Iterator call(Iterator i, Iterator s) { - return Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); - } - }; + (i, s) -> Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator(); JavaRDD sizes = rdd1.zipPartitions(rdd2, sizesFn); assertEquals("[3, 2, 3, 2]", sizes.collect().toString()); @@ -1385,22 +1189,12 @@ public Iterator call(Iterator i, Iterator s) { public void accumulators() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); - final Accumulator intAccum = sc.intAccumulator(10); - rdd.foreach(new 
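The `zip` and `zipPartitions` hunks in this stretch show the other common shape of the refactoring: a named functional-interface variable assigned from a lambda. A compact sketch under the same assumptions as above (local-mode context, illustrative names; `Iterators` is Guava, which the suite already uses):

```java
import java.util.Arrays;
import java.util.Iterator;

import com.google.common.collect.Iterators;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction2;

public class ZipPartitionsSketch {
  public static void main(String[] args) {
    try (JavaSparkContext sc = new JavaSparkContext("local", "zip-partitions-sketch")) {
      JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6), 2);
      JavaRDD<String> rdd2 = sc.parallelize(Arrays.asList("1", "2", "3", "4"), 2);
      // The anonymous FlatMapFunction2 becomes a two-argument lambda over the
      // two partition iterators.
      FlatMapFunction2<Iterator<Integer>, Iterator<String>, Integer> sizesFn =
          (i, s) -> Arrays.asList(Iterators.size(i), Iterators.size(s)).iterator();
      // With two partitions per RDD this collects to [3, 2, 3, 2], as asserted above.
      System.out.println(rdd1.zipPartitions(rdd2, sizesFn).collect());
    }
  }
}
```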
VoidFunction() { - @Override - public void call(Integer x) { - intAccum.add(x); - } - }); + Accumulator intAccum = sc.intAccumulator(10); + rdd.foreach(intAccum::add); assertEquals((Integer) 25, intAccum.value()); - final Accumulator doubleAccum = sc.doubleAccumulator(10.0); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer x) { - doubleAccum.add((double) x); - } - }); + Accumulator doubleAccum = sc.doubleAccumulator(10.0); + rdd.foreach(x -> doubleAccum.add((double) x)); assertEquals((Double) 25.0, doubleAccum.value()); // Try a custom accumulator type @@ -1421,13 +1215,8 @@ public Float zero(Float initialValue) { } }; - final Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); - rdd.foreach(new VoidFunction() { - @Override - public void call(Integer x) { - floatAccum.add((float) x); - } - }); + Accumulator floatAccum = sc.accumulator(10.0f, floatAccumulatorParam); + rdd.foreach(x -> floatAccum.add((float) x)); assertEquals((Float) 25.0f, floatAccum.value()); // Test the setValue method @@ -1438,12 +1227,7 @@ public void call(Integer x) { @Test public void keyBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2)); - List> s = rdd.keyBy(new Function() { - @Override - public String call(Integer t) { - return t.toString(); - } - }).collect(); + List> s = rdd.keyBy(Object::toString).collect(); assertEquals(new Tuple2<>("1", 1), s.get(0)); assertEquals(new Tuple2<>("2", 2), s.get(1)); } @@ -1476,45 +1260,29 @@ public void checkpointAndRestore() { @Test public void combineByKey() { JavaRDD originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)); - Function keyFunction = new Function() { - @Override - public Integer call(Integer v1) { - return v1 % 3; - } - }; - Function createCombinerFunction = new Function() { - @Override - public Integer call(Integer v1) { - return v1; - } - }; + Function keyFunction = v1 -> v1 % 3; + Function createCombinerFunction = v1 -> v1; - Function2 mergeValueFunction = - new Function2() { - @Override - public Integer call(Integer v1, Integer v2) { - return v1 + v2; - } - }; + Function2 mergeValueFunction = (v1, v2) -> v1 + v2; JavaPairRDD combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); + .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction); Map results = combinedRDD.collectAsMap(); ImmutableMap expected = ImmutableMap.of(0, 9, 1, 5, 2, 7); assertEquals(expected, results); Partitioner defaultPartitioner = Partitioner.defaultPartitioner( - combinedRDD.rdd(), - JavaConverters.collectionAsScalaIterableConverter( - Collections.>emptyList()).asScala().toSeq()); + combinedRDD.rdd(), + JavaConverters.collectionAsScalaIterableConverter( + Collections.>emptyList()).asScala().toSeq()); combinedRDD = originalRDD.keyBy(keyFunction) - .combineByKey( - createCombinerFunction, - mergeValueFunction, - mergeValueFunction, - defaultPartitioner, - false, - new KryoSerializer(new SparkConf())); + .combineByKey( + createCombinerFunction, + mergeValueFunction, + mergeValueFunction, + defaultPartitioner, + false, + new KryoSerializer(new SparkConf())); results = combinedRDD.collectAsMap(); assertEquals(expected, results); } @@ -1523,26 +1291,13 @@ public Integer call(Integer v1, Integer v2) { @Test public void mapOnPairRDD() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1,2,3,4)); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i, i % 2); - } - }); - JavaPairRDD 
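One side effect of the lambda form in the accumulator hunks is that the `final` modifiers can go away: a lambda only needs the captured accumulator to be effectively final. A minimal sketch (again local mode, illustrative names; `intAccumulator` is the older accumulator API that this suite still exercises):

```java
import java.util.Arrays;

import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class AccumulatorSketch {
  public static void main(String[] args) {
    try (JavaSparkContext sc = new JavaSparkContext("local", "accumulator-sketch")) {
      JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
      // No "final" needed: the method reference only requires intAccum to be
      // effectively final, which it is.
      Accumulator<Integer> intAccum = sc.intAccumulator(10);
      rdd.foreach(intAccum::add);
      System.out.println(intAccum.value()); // 10 + (1 + 2 + 3 + 4 + 5) = 25
    }
  }
}
```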
rdd3 = rdd2.mapToPair( - new PairFunction, Integer, Integer>() { - @Override - public Tuple2 call(Tuple2 in) { - return new Tuple2<>(in._2(), in._1()); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); + JavaPairRDD rdd3 = rdd2.mapToPair(in -> new Tuple2<>(in._2(), in._1())); assertEquals(Arrays.asList( - new Tuple2<>(1, 1), - new Tuple2<>(0, 2), - new Tuple2<>(1, 3), - new Tuple2<>(0, 4)), rdd3.collect()); - + new Tuple2<>(1, 1), + new Tuple2<>(0, 2), + new Tuple2<>(1, 3), + new Tuple2<>(0, 4)), rdd3.collect()); } @SuppressWarnings("unchecked") @@ -1550,13 +1305,7 @@ public Tuple2 call(Tuple2 in) { public void collectPartitions() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i, i % 2); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i, i % 2)); List[] parts = rdd1.collectPartitions(new int[] {0}); assertEquals(Arrays.asList(1, 2), parts[0]); @@ -1565,16 +1314,18 @@ public Tuple2 call(Integer i) { assertEquals(Arrays.asList(3, 4), parts[0]); assertEquals(Arrays.asList(5, 6, 7), parts[1]); - assertEquals(Arrays.asList(new Tuple2<>(1, 1), - new Tuple2<>(2, 0)), - rdd2.collectPartitions(new int[] {0})[0]); + assertEquals( + Arrays.asList(new Tuple2<>(1, 1), new Tuple2<>(2, 0)), + rdd2.collectPartitions(new int[] {0})[0]); List>[] parts2 = rdd2.collectPartitions(new int[] {1, 2}); assertEquals(Arrays.asList(new Tuple2<>(3, 1), new Tuple2<>(4, 0)), parts2[0]); - assertEquals(Arrays.asList(new Tuple2<>(5, 1), - new Tuple2<>(6, 0), - new Tuple2<>(7, 1)), - parts2[1]); + assertEquals( + Arrays.asList( + new Tuple2<>(5, 1), + new Tuple2<>(6, 0), + new Tuple2<>(7, 1)), + parts2[1]); } @Test @@ -1605,20 +1356,13 @@ public void countApproxDistinctByKey() { double error = Math.abs((resCount - count) / count); assertTrue(error < 0.1); } - } @Test public void collectAsMapWithIntArrayValues() { // Regression test for SPARK-1040 JavaRDD rdd = sc.parallelize(Arrays.asList(1)); - JavaPairRDD pairRDD = rdd.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer x) { - return new Tuple2<>(x, new int[]{x}); - } - }); + JavaPairRDD pairRDD = rdd.mapToPair(x -> new Tuple2<>(x, new int[]{x})); pairRDD.collect(); // Works fine pairRDD.collectAsMap(); // Used to crash with ClassCastException } @@ -1640,13 +1384,7 @@ public void collectAsMapAndSerialize() throws Exception { @SuppressWarnings("unchecked") public void sampleByKey() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i % 2, 1); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); Map fractions = new HashMap<>(); fractions.put(0, 0.5); fractions.put(1, 1.0); @@ -1666,13 +1404,7 @@ public Tuple2 call(Integer i) { @SuppressWarnings("unchecked") public void sampleByKeyExact() { JavaRDD rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 3); - JavaPairRDD rdd2 = rdd1.mapToPair( - new PairFunction() { - @Override - public Tuple2 call(Integer i) { - return new Tuple2<>(i % 2, 1); - } - }); + JavaPairRDD rdd2 = rdd1.mapToPair(i -> new Tuple2<>(i % 2, 1)); Map fractions = new HashMap<>(); fractions.put(0, 0.5); fractions.put(1, 1.0); @@ -1743,14 +1475,7 @@ public void takeAsync() throws Exception { public void foreachAsync() throws Exception { List data = Arrays.asList(1, 
2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.foreachAsync( - new VoidFunction() { - @Override - public void call(Integer integer) { - // intentionally left blank. - } - } - ); + JavaFutureAction future = rdd.foreachAsync(integer -> {}); future.get(); assertFalse(future.isCancelled()); assertTrue(future.isDone()); @@ -1773,11 +1498,8 @@ public void countAsync() throws Exception { public void testAsyncActionCancellation() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.foreachAsync(new VoidFunction() { - @Override - public void call(Integer integer) throws InterruptedException { - Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. - } + JavaFutureAction future = rdd.foreachAsync(integer -> { + Thread.sleep(10000); // To ensure that the job won't finish before it's cancelled. }); future.cancel(true); assertTrue(future.isCancelled()); @@ -1794,7 +1516,7 @@ public void call(Integer integer) throws InterruptedException { public void testAsyncActionErrorWrapping() throws Exception { List data = Arrays.asList(1, 2, 3, 4, 5); JavaRDD rdd = sc.parallelize(data, 1); - JavaFutureAction future = rdd.map(new BuggyMapFunction()).countAsync(); + JavaFutureAction future = rdd.map(new BuggyMapFunction<>()).countAsync(); try { future.get(2, TimeUnit.SECONDS); fail("Expected future.get() for failed job to throw ExcecutionException"); @@ -1812,8 +1534,8 @@ public void testRegisterKryoClasses() { SparkConf conf = new SparkConf(); conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class }); assertEquals( - Class1.class.getName() + "," + Class2.class.getName(), - conf.get("spark.kryo.classesToRegister")); + Class1.class.getName() + "," + Class2.class.getName(), + conf.get("spark.kryo.classesToRegister")); } @Test diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index cba44c848e012..f2c3ec5da8891 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "app-20161116163331-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", + "lastUpdated" : "", + "duration" : 10671, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479335620587, + "startTimeEpoch" : 1479335609916, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "app-20161115172038-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-15T23:20:37.079GMT", + "endTime" : "2016-11-15T23:22:18.874GMT", + "lastUpdated" : "", + "duration" : 101795, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479252138874, + "startTimeEpoch" : 1479252037079, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { @@ -8,8 +38,9 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 } ] }, { @@ -23,8 +54,9 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, 
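The remaining hunks update HistoryServerSuite expectation files rather than test code; they mirror what the history server's monitoring REST API returns, with new per-attempt fields such as `appSparkVersion` alongside the epoch timestamps. A hedged sketch of fetching that listing, assuming a history server on the default port 18080 (the host, port, and bare-bones HTTP client are placeholders, not part of the patch):

```java
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.net.URL;
import java.nio.charset.StandardCharsets;

public class HistoryServerListSketch {
  public static void main(String[] args) throws Exception {
    // Assumes a history server running locally on its default port.
    URL url = new URL("http://localhost:18080/api/v1/applications");
    try (BufferedReader in = new BufferedReader(
        new InputStreamReader(url.openStream(), StandardCharsets.UTF_8))) {
      // Each application attempt in the response now carries fields such as
      // "appSparkVersion", "startTimeEpoch" and "endTimeEpoch", as in these fixtures.
      in.lines().forEach(System.out::println);
    }
  }
}
```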
- "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -34,8 +66,9 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380880, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, + "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 } ] }, { @@ -49,8 +82,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426633910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, + "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -60,8 +94,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426533910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, + "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 } ] }, { @@ -74,8 +109,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1425081758277, + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, + "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 } ] }, { @@ -88,8 +124,9 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981779720, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, + "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 } ] }, { @@ -102,8 +139,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981758277, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, + "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index cba44c848e012..c925c1dd8a4d3 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "app-20161116163331-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", + "lastUpdated" : "", + "duration" : 10671, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479335620587, + "startTimeEpoch" : 1479335609916, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "app-20161115172038-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-15T23:20:37.079GMT", + "endTime" : "2016-11-15T23:22:18.874GMT", + "lastUpdated" : "", + "duration" : 101795, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479252138874, + "startTimeEpoch" : 1479252037079, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { @@ -8,8 +38,9 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 } ] }, { @@ -23,8 +54,9 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 }, { "attemptId" : 
"1", @@ -34,8 +66,9 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380880, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, + "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 } ] }, { @@ -49,8 +82,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426633910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, + "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -60,8 +94,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426533910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, + "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 } ] }, { @@ -74,8 +109,10 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1425081758277, + "appSparkVersion" : "", + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, + "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 } ] }, { @@ -88,8 +125,9 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981779720, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, + "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 } ] }, { @@ -102,8 +140,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981758277, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, + "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index e7db6742c25e1..6b9f29e1a230e 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -16,6 +16,7 @@ "totalInputBytes" : 28000288, "totalShuffleRead" : 0, "totalShuffleWrite" : 13180, + "isBlacklisted" : false, "maxMemory" : 278302556, "executorLogs" : { } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json new file mode 100644 index 0000000000000..0f94e3b255dbc --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -0,0 +1,148 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.167:51487", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + 
"failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "1", + "hostPort" : "172.22.0.167:51490", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3152, + "totalGCTime" : 68, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "0", + "hostPort" : "172.22.0.167:51491", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2551, + "totalGCTime" : 116, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json new file mode 100644 index 0000000000000..0f94e3b255dbc --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -0,0 +1,148 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.167:51487", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + 
"maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2537, + "totalGCTime" : 88, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "driver", + "hostPort" : "172.22.0.167:51475", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "1", + "hostPort" : "172.22.0.167:51490", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3152, + "totalGCTime" : 68, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "0", + "hostPort" : "172.22.0.167:51491", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2551, + "totalGCTime" : 116, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +}, { + "id" : "3", + "hostPort" : "172.22.0.167:51485", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2453, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : true, + "maxMemory" : 908381388, + "executorLogs" : { + "stdout" : 
"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" + }, + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json new file mode 100644 index 0000000000000..92e249c851116 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_unblacklisting_expectation.json @@ -0,0 +1,118 @@ +[ { + "id" : "2", + "hostPort" : "172.22.0.111:64539", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 6, + "completedTasks" : 0, + "totalTasks" : 6, + "totalDuration" : 2792, + "totalGCTime" : 128, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout", + "stderr" : "http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr" + } +}, { + "id" : "driver", + "hostPort" : "172.22.0.111:64527", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 0, + "maxTasks" : 0, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 0, + "totalTasks" : 0, + "totalDuration" : 0, + "totalGCTime" : 0, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "executorLogs" : { } +}, { + "id" : "1", + "hostPort" : "172.22.0.111:64541", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + "completedTasks" : 12, + "totalTasks" : 12, + "totalDuration" : 2613, + "totalGCTime" : 84, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout", + "stderr" : "http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr" + } +}, { + "id" : "0", + "hostPort" : "172.22.0.111:64540", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 4, + "completedTasks" : 0, + "totalTasks" : 4, + "totalDuration" : 2741, + "totalGCTime" : 120, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout", + "stderr" : "http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr" + } +}, { + "id" : "3", + "hostPort" : "172.22.0.111:64543", + "isActive" : true, + "rddBlocks" : 0, + "memoryUsed" : 0, + "diskUsed" : 0, + "totalCores" : 4, + "maxTasks" : 4, + "activeTasks" : 0, + "failedTasks" : 0, + 
"completedTasks" : 4, + "totalTasks" : 4, + "totalDuration" : 3457, + "totalGCTime" : 72, + "totalInputBytes" : 0, + "totalShuffleRead" : 0, + "totalShuffleWrite" : 0, + "isBlacklisted" : false, + "maxMemory" : 384093388, + "executorLogs" : { + "stdout" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout", + "stderr" : "http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr" + } +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json index 9165f549d7d25..cc0b2b0022bd3 100644 --- a/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/limit_app_list_json_expectation.json @@ -1,67 +1,46 @@ [ { - "id" : "local-1430917381534", + "id" : "app-20161116163331-0000", "name" : "Spark shell", "attempts" : [ { - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:11.398GMT", + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", "lastUpdated" : "", - "duration" : 10505, - "sparkUser" : "irashid", + "duration" : 10671, + "sparkUser" : "jose", "completed" : true, - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917391398, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479335620587, + "startTimeEpoch" : 1479335609916, "lastUpdatedEpoch" : 0 } ] }, { - "id" : "local-1430917381535", + "id" : "app-20161115172038-0000", "name" : "Spark shell", "attempts" : [ { - "attemptId" : "2", - "startTime" : "2015-05-06T13:03:00.893GMT", - "endTime" : "2015-05-06T13:03:00.950GMT", - "lastUpdated" : "", - "duration" : 57, - "sparkUser" : "irashid", - "completed" : true, - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917380950, - "lastUpdatedEpoch" : 0 - }, { - "attemptId" : "1", - "startTime" : "2015-05-06T13:03:00.880GMT", - "endTime" : "2015-05-06T13:03:00.890GMT", + "startTime" : "2016-11-15T23:20:37.079GMT", + "endTime" : "2016-11-15T23:22:18.874GMT", "lastUpdated" : "", - "duration" : 10, - "sparkUser" : "irashid", + "duration" : 101795, + "sparkUser" : "jose", "completed" : true, - "startTimeEpoch" : 1430917380880, - "endTimeEpoch" : 1430917380890, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479252138874, + "startTimeEpoch" : 1479252037079, "lastUpdatedEpoch" : 0 } ] }, { - "id" : "local-1426533911241", + "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { - "attemptId" : "2", - "startTime" : "2015-03-17T23:11:50.242GMT", - "endTime" : "2015-03-17T23:12:25.177GMT", - "lastUpdated" : "", - "duration" : 34935, - "sparkUser" : "irashid", - "completed" : true, - "startTimeEpoch" : 1426633910242, - "endTimeEpoch" : 1426633945177, - "lastUpdatedEpoch" : 0 - }, { - "attemptId" : "1", - "startTime" : "2015-03-16T19:25:10.242GMT", - "endTime" : "2015-03-16T19:25:45.177GMT", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", - "duration" : 34935, + "duration" : 10505, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426533910242, - "endTimeEpoch" : 1426533945177, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "endTimeEpoch" : 1430917391398, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index a525d61543a88..fa12413eeb0e6 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -8,8 +8,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981758277, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, + "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index cc567f66f02e8..a0d4a0d1c4554 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -8,8 +8,9 @@ "duration" : 9011, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981779720, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, + "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 } ] }, { @@ -22,8 +23,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1422981758277, + "appSparkVersion" : "", "endTimeEpoch" : 1422981766912, + "startTimeEpoch" : 1422981758277, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json new file mode 100644 index 0000000000000..dfa90010c6ca1 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/maxEndDate_app_list_json_expectation.json @@ -0,0 +1,102 @@ +[ { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890 + } ] +}, { + "id" : "local-1426533911241", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-03-17T23:11:50.242GMT", + "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177 + }, { + "attemptId" : "1", + "startTime" : "2015-03-16T19:25:10.242GMT", + "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177 + } ] +}, { + "id" : "local-1425081759269", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-02-28T00:02:38.277GMT", + "endTime" : "2015-02-28T00:02:46.912GMT", + 
"lastUpdated" : "", + "duration" : 8635, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912 + } ] +}, { + "id" : "local-1422981780767", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-02-03T16:42:59.720GMT", + "endTime" : "2015-02-03T16:43:08.731GMT", + "lastUpdated" : "", + "duration" : 9011, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731 + } ] +}, { + "id" : "local-1422981759269", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-02-03T16:42:38.277GMT", + "endTime" : "2015-02-03T16:42:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912 + } ] +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json new file mode 100644 index 0000000000000..3ebe60e2cd033 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/minDate_and_maxEndDate_app_list_json_expectation.json @@ -0,0 +1,57 @@ +[ { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890 + } ] +}, { + "id" : "local-1426533911241", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-03-17T23:11:50.242GMT", + "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177 + }, { + "attemptId" : "1", + "startTime" : "2015-03-16T19:25:10.242GMT", + "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177 + } ] +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index c934a871724b5..5af50abd85330 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -1,4 +1,34 @@ [ { + "id" : "app-20161116163331-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : 
"2016-11-16T22:33:40.587GMT", + "lastUpdated" : "", + "duration" : 10671, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479335620587, + "startTimeEpoch" : 1479335609916, + "lastUpdatedEpoch" : 0 + } ] +}, { + "id" : "app-20161115172038-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-15T23:20:37.079GMT", + "endTime" : "2016-11-15T23:22:18.874GMT", + "lastUpdated" : "", + "duration" : 101795, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "endTimeEpoch" : 1479252138874, + "startTimeEpoch" : 1479252037079, + "lastUpdatedEpoch" : 0 + } ] +}, { "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { @@ -8,8 +38,9 @@ "duration" : 10505, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917391398, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 } ] }, { @@ -23,8 +54,9 @@ "duration" : 57, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380893, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380950, + "startTimeEpoch" : 1430917380893, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -34,8 +66,9 @@ "duration" : 10, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1430917380880, + "appSparkVersion" : "1.4.0-SNAPSHOT", "endTimeEpoch" : 1430917380890, + "startTimeEpoch" : 1430917380880, "lastUpdatedEpoch" : 0 } ] }, { @@ -49,8 +82,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426633910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, + "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -60,8 +94,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426533910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, + "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 } ] }, { @@ -74,8 +109,9 @@ "duration" : 8635, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1425081758277, + "appSparkVersion" : "", "endTimeEpoch" : 1425081766912, + "startTimeEpoch" : 1425081758277, "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json new file mode 100644 index 0000000000000..74a7b40a59272 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_and_maxEndDate_app_list_json_expectation.json @@ -0,0 +1,57 @@ +[ { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890 + } ] +}, { + "id" : "local-1426533911241", + "name" : "Spark shell", + "attempts" 
: [ { + "attemptId" : "2", + "startTime" : "2015-03-17T23:11:50.242GMT", + "endTime" : "2015-03-17T23:12:25.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177 + }, { + "attemptId" : "1", + "startTime" : "2015-03-16T19:25:10.242GMT", + "endTime" : "2015-03-16T19:25:45.177GMT", + "lastUpdated" : "", + "duration" : 34935, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177 + } ] +} ] \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json new file mode 100644 index 0000000000000..7f896c74b5be1 --- /dev/null +++ b/core/src/test/resources/HistoryServerExpectations/minEndDate_app_list_json_expectation.json @@ -0,0 +1,74 @@ +[ { + "id" : "app-20161116163331-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-16T22:33:29.916GMT", + "endTime" : "2016-11-16T22:33:40.587GMT", + "lastUpdated" : "", + "duration" : 10671, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "startTimeEpoch" : 1479335609916, + "lastUpdatedEpoch" : 0, + "endTimeEpoch" : 1479335620587 + } ] +}, { + "id" : "app-20161115172038-0000", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2016-11-15T23:20:37.079GMT", + "endTime" : "2016-11-15T23:22:18.874GMT", + "lastUpdated" : "", + "duration" : 101795, + "sparkUser" : "jose", + "completed" : true, + "appSparkVersion" : "2.1.0-SNAPSHOT", + "startTimeEpoch" : 1479252037079, + "lastUpdatedEpoch" : 0, + "endTimeEpoch" : 1479252138874 + } ] +}, { + "id" : "local-1430917381534", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:11.398GMT", + "lastUpdated" : "", + "duration" : 10505, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398 + } ] +}, { + "id" : "local-1430917381535", + "name" : "Spark shell", + "attempts" : [ { + "attemptId" : "2", + "startTime" : "2015-05-06T13:03:00.893GMT", + "endTime" : "2015-05-06T13:03:00.950GMT", + "lastUpdated" : "", + "duration" : 57, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950 + }, { + "attemptId" : "1", + "startTime" : "2015-05-06T13:03:00.880GMT", + "endTime" : "2015-05-06T13:03:00.890GMT", + "lastUpdated" : "", + "duration" : 10, + "sparkUser" : "irashid", + "completed" : true, + "appSparkVersion" : "1.4.0-SNAPSHOT", + "lastUpdatedEpoch" : 0, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890 + } ] +} ] \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index f486d46313d8b..24ec6a163fc2c 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -8,8 +8,9 @@ "duration" : 9011, "sparkUser" : 
"irashid", "completed" : true, - "startTimeEpoch" : 1422981779720, + "appSparkVersion" : "", "endTimeEpoch" : 1422981788731, + "startTimeEpoch" : 1422981779720, "lastUpdatedEpoch" : 0 } ] } diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index e63039f6a17fc..94b6d6dba76e9 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -9,8 +9,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426633910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426633945177, + "startTimeEpoch" : 1426633910242, "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", @@ -20,8 +21,9 @@ "duration" : 34935, "sparkUser" : "irashid", "completed" : true, - "startTimeEpoch" : 1426533910242, + "appSparkVersion" : "", "endTimeEpoch" : 1426533945177, + "startTimeEpoch" : 1426533910242, "lastUpdatedEpoch" : 0 } ] } diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index 0084339d24642..c2f450ba87c6d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -30,8 +30,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 435, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -68,24 +70,26 @@ } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 436, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -105,19 +109,21 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -149,18 +155,20 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : 
"PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,7 +179,7 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -191,19 +199,21 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "12" : { + "taskId" : 12, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -234,30 +244,32 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 76000, + "bytesWritten" : 1645, + "writeTime" : 101000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "13" : { + "taskId" : 13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -278,18 +290,20 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 73000, "recordsWritten" : 0 } } }, - "12" : { - "taskId" : 12, - "index" : 4, + "14" : { + "taskId" : 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -320,8 +334,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1645, - "writeTime" : 101000, + "bytesWritten" : 1648, + "writeTime" : 88000, "recordsWritten" : 0 } } @@ -331,8 +345,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", + "duration" : 435, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 63fe3b2f958e5..506859ae545b1 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -30,8 +30,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.829GMT", + "duration" : 435, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -68,24 +70,26 @@ } } }, - "11" : { - "taskId" : 11, - "index" : 3, + "9" : { + "taskId" : 9, + "index" : 1, "attempt" : 0, "launchTime" : 
"2015-02-03T16:43:05.830GMT", + "duration" : 436, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 2, + "executorDeserializeTime" : 1, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 434, + "executorRunTime" : 436, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 1, + "resultSerializationTime" : 0, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -105,19 +109,21 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1647, - "writeTime" : 83000, + "bytesWritten" : 1648, + "writeTime" : 98000, "recordsWritten" : 0 } } }, - "14" : { - "taskId" : 14, - "index" : 6, + "10" : { + "taskId" : 10, + "index" : 2, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.832GMT", + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -149,18 +155,20 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 88000, + "writeTime" : 76000, "recordsWritten" : 0 } } }, - "13" : { - "taskId" : 13, - "index" : 5, + "11" : { + "taskId" : 11, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", + "launchTime" : "2015-02-03T16:43:05.830GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,7 +179,7 @@ "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -191,19 +199,21 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 73000, + "bytesWritten" : 1647, + "writeTime" : 83000, "recordsWritten" : 0 } } }, - "10" : { - "taskId" : 10, - "index" : 2, + "12" : { + "taskId" : 12, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -234,30 +244,32 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1648, - "writeTime" : 76000, + "bytesWritten" : 1645, + "writeTime" : 101000, "recordsWritten" : 0 } } }, - "9" : { - "taskId" : 9, - "index" : 1, + "13" : { + "taskId" : 13, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.830GMT", + "launchTime" : "2015-02-03T16:43:05.831GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], "taskMetrics" : { - "executorDeserializeTime" : 1, + "executorDeserializeTime" : 2, "executorDeserializeCpuTime" : 0, - "executorRunTime" : 436, + "executorRunTime" : 434, "executorCpuTime" : 0, "resultSize" : 1902, "jvmGcTime" : 19, - "resultSerializationTime" : 0, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -278,18 +290,20 @@ }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, - "writeTime" : 98000, + "writeTime" : 73000, "recordsWritten" : 0 } } }, - "12" : { - "taskId" : 12, - "index" : 4, + "14" : { + "taskId" 
: 14, + "index" : 6, "attempt" : 0, - "launchTime" : "2015-02-03T16:43:05.831GMT", + "launchTime" : "2015-02-03T16:43:05.832GMT", + "duration" : 434, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -320,8 +334,8 @@ "recordsRead" : 0 }, "shuffleWriteMetrics" : { - "bytesWritten" : 1645, - "writeTime" : 101000, + "bytesWritten" : 1648, + "writeTime" : 88000, "recordsWritten" : 0 } } @@ -331,8 +345,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-02-03T16:43:05.833GMT", + "duration" : 435, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index e0661c464179d..f4cec68fbfdf2 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -3,8 +3,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -45,8 +47,10 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -87,8 +91,10 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", + "duration" : 348, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -129,8 +135,10 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,8 +179,10 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -213,8 +223,10 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -255,8 +267,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 351, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -297,8 +311,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -339,8 +355,10 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", + "duration" : 80, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, 
"accumulatorUpdates" : [ ], @@ -381,8 +399,10 @@ "index" : 9, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -423,8 +443,10 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", + "duration" : 73, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -465,8 +487,10 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", + "duration" : 75, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -507,8 +531,10 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", + "duration" : 77, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -549,8 +575,10 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -591,8 +619,10 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -633,8 +663,10 @@ "index" : 15, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -675,8 +707,10 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -717,8 +751,10 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", + "duration" : 91, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -759,8 +795,10 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", + "duration" : 92, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -801,8 +839,10 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index 8492f19ab7a5f..496a21c328da9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -3,8 +3,10 @@ "index" : 0, 
"attempt" : 0, "launchTime" : "2015-03-16T19:25:36.515GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -50,8 +52,10 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.521GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -97,8 +101,10 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -144,8 +150,10 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -191,8 +199,10 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -238,8 +248,10 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -285,8 +297,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -332,8 +346,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.524GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 4de4c501a43ad..4328dc753c5d4 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -3,8 +3,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.515GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -50,8 +52,10 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.521GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -97,8 +101,10 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -144,8 +150,10 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", 
"taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -191,8 +199,10 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -238,8 +248,10 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -285,8 +297,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -332,8 +346,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-03-17T23:12:16.524GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index d2eceeb3f97a9..8c571430f3a1f 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -3,8 +3,10 @@ "index" : 10, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.916GMT", + "duration" : 73, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -45,8 +47,10 @@ "index" : 11, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.918GMT", + "duration" : 75, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -87,8 +91,10 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", + "duration" : 77, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -129,8 +135,10 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,8 +179,10 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -213,8 +223,10 @@ "index" : 15, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.928GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -255,8 +267,10 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -297,8 +311,10 @@ "index" : 17, "attempt" : 0, "launchTime" 
: "2015-05-06T13:03:07.005GMT", + "duration" : 91, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -339,8 +355,10 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", + "duration" : 92, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -381,8 +399,10 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -423,8 +443,10 @@ "index" : 20, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -465,8 +487,10 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", + "duration" : 88, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -507,8 +531,10 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", + "duration" : 93, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -549,8 +575,10 @@ "index" : 23, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.031GMT", + "duration" : 65, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -591,8 +619,10 @@ "index" : 24, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.098GMT", + "duration" : 43, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -633,8 +663,10 @@ "index" : 25, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.103GMT", + "duration" : 49, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -675,8 +707,10 @@ "index" : 26, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.105GMT", + "duration" : 38, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -717,8 +751,10 @@ "index" : 27, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.110GMT", + "duration" : 32, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -759,8 +795,10 @@ "index" : 28, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.113GMT", + "duration" : 29, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -801,8 +839,10 @@ "index" : 29, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.114GMT", + "duration" : 39, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -843,8 +883,10 @@ "index" : 30, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.118GMT", + "duration" 
: 34, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -885,8 +927,10 @@ "index" : 31, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.127GMT", + "duration" : 24, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -927,8 +971,10 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -969,8 +1015,10 @@ "index" : 33, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.149GMT", + "duration" : 43, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1011,8 +1059,10 @@ "index" : 34, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.156GMT", + "duration" : 27, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1053,8 +1103,10 @@ "index" : 35, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.161GMT", + "duration" : 35, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1095,8 +1147,10 @@ "index" : 36, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.164GMT", + "duration" : 29, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1137,8 +1191,10 @@ "index" : 37, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.165GMT", + "duration" : 32, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1179,8 +1235,10 @@ "index" : 38, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.166GMT", + "duration" : 31, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1221,8 +1279,10 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1263,8 +1323,10 @@ "index" : 40, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.197GMT", + "duration" : 14, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1305,8 +1367,10 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1347,8 +1411,10 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1389,8 +1455,10 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", + "duration" : 16, "executorId" : 
"driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1431,8 +1499,10 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1473,8 +1543,10 @@ "index" : 45, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.206GMT", + "duration" : 19, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1515,8 +1587,10 @@ "index" : 46, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.210GMT", + "duration" : 31, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1557,8 +1631,10 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1599,8 +1675,10 @@ "index" : 48, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.220GMT", + "duration" : 24, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1641,8 +1719,10 @@ "index" : 49, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.223GMT", + "duration" : 23, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1683,8 +1763,10 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1725,8 +1807,10 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1767,8 +1851,10 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1809,8 +1895,10 @@ "index" : 53, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1851,8 +1939,10 @@ "index" : 54, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.244GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1893,8 +1983,10 @@ "index" : 55, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.246GMT", + "duration" : 21, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1935,8 +2027,10 @@ "index" : 56, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.249GMT", + "duration" : 20, "executorId" : "driver", "host" : 
"localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -1977,8 +2071,10 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -2019,8 +2115,10 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -2061,8 +2159,10 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index f42c3a4ee5c38..0bd614bdc756e 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -3,8 +3,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 351, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -45,8 +47,10 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -87,8 +91,10 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -129,8 +135,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,8 +179,10 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -213,8 +223,10 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -255,8 +267,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -297,8 +311,10 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", + "duration" : 348, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -339,8 +355,10 @@ 
"index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", + "duration" : 93, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -381,8 +399,10 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", + "duration" : 92, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -423,8 +443,10 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", + "duration" : 91, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -465,8 +487,10 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", + "duration" : 88, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -507,8 +531,10 @@ "index" : 9, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.915GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -549,8 +575,10 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -591,8 +619,10 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -633,8 +663,10 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -675,8 +707,10 @@ "index" : 20, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -717,8 +751,10 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", + "duration" : 80, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -759,8 +795,10 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", + "duration" : 77, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -801,8 +839,10 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index f42c3a4ee5c38..0bd614bdc756e 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -3,8 +3,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 351, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -45,8 +47,10 @@ "index" : 1, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.502GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -87,8 +91,10 @@ "index" : 5, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.505GMT", + "duration" : 350, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -129,8 +135,10 @@ "index" : 0, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.494GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,8 +179,10 @@ "index" : 3, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -213,8 +223,10 @@ "index" : 4, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.504GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -255,8 +267,10 @@ "index" : 7, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.506GMT", + "duration" : 349, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -297,8 +311,10 @@ "index" : 2, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.503GMT", + "duration" : 348, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -339,8 +355,10 @@ "index" : 22, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.018GMT", + "duration" : 93, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -381,8 +399,10 @@ "index" : 18, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.010GMT", + "duration" : 92, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -423,8 +443,10 @@ "index" : 17, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.005GMT", + "duration" : 91, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -465,8 +487,10 @@ "index" : 21, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.015GMT", + "duration" : 88, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -507,8 +531,10 @@ "index" : 9, "attempt" : 0, "launchTime" : 
"2015-05-06T13:03:06.915GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -549,8 +575,10 @@ "index" : 16, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.001GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -591,8 +619,10 @@ "index" : 19, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.012GMT", + "duration" : 84, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -633,8 +663,10 @@ "index" : 14, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.925GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -675,8 +707,10 @@ "index" : 20, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.014GMT", + "duration" : 83, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -717,8 +751,10 @@ "index" : 8, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.914GMT", + "duration" : 80, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -759,8 +795,10 @@ "index" : 12, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.923GMT", + "duration" : 77, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -801,8 +839,10 @@ "index" : 13, "attempt" : 0, "launchTime" : "2015-05-06T13:03:06.924GMT", + "duration" : 76, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index db60ccccbf8c8..b58f1a51ba481 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -3,8 +3,10 @@ "index" : 40, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.197GMT", + "duration" : 14, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -45,8 +47,10 @@ "index" : 41, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.200GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -87,8 +91,10 @@ "index" : 43, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.204GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -129,8 +135,10 @@ "index" : 57, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.257GMT", + "duration" : 16, "executorId" : 
"driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -171,8 +179,10 @@ "index" : 58, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.263GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -213,8 +223,10 @@ "index" : 68, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.306GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -255,8 +267,10 @@ "index" : 86, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", + "duration" : 16, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -297,8 +311,10 @@ "index" : 32, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.148GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -339,8 +355,10 @@ "index" : 39, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.180GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -381,8 +399,10 @@ "index" : 42, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.203GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -423,8 +443,10 @@ "index" : 51, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.242GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -465,8 +487,10 @@ "index" : 59, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.265GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -507,8 +531,10 @@ "index" : 63, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.276GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -549,8 +575,10 @@ "index" : 87, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.374GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -591,8 +619,10 @@ "index" : 90, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.385GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -633,8 +663,10 @@ "index" : 99, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.426GMT", + "duration" : 17, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -675,8 +707,10 @@ "index" : 44, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.205GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : 
"SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -717,8 +751,10 @@ "index" : 47, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.212GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -759,8 +795,10 @@ "index" : 50, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.240GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], @@ -801,8 +839,10 @@ "index" : 52, "attempt" : 0, "launchTime" : "2015-05-06T13:03:07.243GMT", + "duration" : 18, "executorId" : "driver", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index 5dcbc890438b2..0ed609d5b7f92 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -3,7 +3,7 @@ "executorDeserializeTime" : [ 1.0, 3.0, 36.0 ], "executorDeserializeCpuTime" : [ 0.0, 0.0, 0.0 ], "executorRunTime" : [ 16.0, 28.0, 351.0 ], - "executorCpuTime" : [ 0.0, 0.0, 0.0], + "executorCpuTime" : [ 0.0, 0.0, 0.0 ], "resultSize" : [ 2010.0, 2065.0, 2065.0 ], "jvmGcTime" : [ 0.0, 0.0, 7.0 ], "resultSerializationTime" : [ 0.0, 0.0, 2.0 ], diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index aaeef1f2f582c..a449926ee7dc6 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -29,23 +29,25 @@ "value" : "5050" } ], "tasks" : { - "2" : { - "taskId" : 2, - "index" : 2, + "0" : { + "taskId" : 0, + "index" : 0, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", + "launchTime" : "2015-03-16T19:25:36.515GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "378", - "value" : "378" + "update" : "78", + "value" : "5050" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 14, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -77,23 +79,25 @@ } } }, - "5" : { - "taskId" : 5, - "index" : 5, + "1" : { + "taskId" : 1, + "index" : 1, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.523GMT", + "launchTime" : "2015-03-16T19:25:36.521GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "897", - "value" : "3750" + "update" : "247", + "value" : "2175" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 14, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, 
"executorCpuTime" : 0, @@ -125,29 +129,31 @@ } } }, - "4" : { - "taskId" : 4, - "index" : 4, + "2" : { + "taskId" : 2, + "index" : 2, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "678", - "value" : "2853" + "update" : "378", + "value" : "378" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 1, + "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -173,23 +179,25 @@ } } }, - "7" : { - "taskId" : 7, - "index" : 7, + "3" : { + "taskId" : 3, + "index" : 3, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.524GMT", + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "1222", - "value" : "4972" + "update" : "572", + "value" : "950" } ], "taskMetrics" : { - "executorDeserializeTime" : 12, + "executorDeserializeTime" : 13, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -221,29 +229,31 @@ } } }, - "1" : { - "taskId" : 1, - "index" : 1, + "4" : { + "taskId" : 4, + "index" : 4, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.521GMT", + "launchTime" : "2015-03-16T19:25:36.522GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "247", - "value" : "2175" + "update" : "678", + "value" : "2853" } ], "taskMetrics" : { - "executorDeserializeTime" : 14, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, "resultSize" : 697, "jvmGcTime" : 0, - "resultSerializationTime" : 2, + "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, "diskBytesSpilled" : 0, "inputMetrics" : { @@ -269,23 +279,25 @@ } } }, - "3" : { - "taskId" : 3, - "index" : 3, + "5" : { + "taskId" : 5, + "index" : 5, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.522GMT", + "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "572", - "value" : "950" + "update" : "897", + "value" : "3750" } ], "taskMetrics" : { - "executorDeserializeTime" : 13, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, @@ -322,8 +334,10 @@ "index" : 6, "attempt" : 0, "launchTime" : "2015-03-16T19:25:36.523GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : "SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { @@ -365,23 +379,25 @@ } } }, - "0" : { - "taskId" : 0, - "index" : 0, + "7" : { + "taskId" : 7, + "index" : 7, "attempt" : 0, - "launchTime" : "2015-03-16T19:25:36.515GMT", + "launchTime" : "2015-03-16T19:25:36.524GMT", + "duration" : 15, "executorId" : "", "host" : "localhost", + "status" : 
"SUCCESS", "taskLocality" : "PROCESS_LOCAL", "speculative" : false, "accumulatorUpdates" : [ { "id" : 1, "name" : "my counter", - "update" : "78", - "value" : "5050" + "update" : "1222", + "value" : "4972" } ], "taskMetrics" : { - "executorDeserializeTime" : 14, + "executorDeserializeTime" : 12, "executorDeserializeCpuTime" : 0, "executorRunTime" : 15, "executorCpuTime" : 0, diff --git a/core/src/test/resources/fairscheduler-with-invalid-data.xml b/core/src/test/resources/fairscheduler-with-invalid-data.xml new file mode 100644 index 0000000000000..a4d8d07b67ce4 --- /dev/null +++ b/core/src/test/resources/fairscheduler-with-invalid-data.xml @@ -0,0 +1,80 @@ + + + + + + INVALID_MIN_SHARE + 2 + FAIR + + + 1 + INVALID_WEIGHT + FAIR + + + 3 + 2 + INVALID_SCHEDULING_MODE + + + 2 + 1 + fair + + + 1 + 2 + NONE + + + + 2 + FAIR + + + 1 + + FAIR + + + 3 + 2 + + + + + 3 + FAIR + + + 2 + + FAIR + + + 2 + 2 + + + + 3 + 2 + FAIR + + diff --git a/core/src/test/resources/spark-events/app-20161115172038-0000 b/core/src/test/resources/spark-events/app-20161115172038-0000 new file mode 100755 index 0000000000000..3af0451d0c392 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20161115172038-0000 @@ -0,0 +1,75 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.1.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.111","Port":64527},"Maximum Memory":384093388,"Timestamp":1479252038836} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","Java Version":"1.8.0_92 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.task.maxTaskAttemptsPerExecutor":"3","spark.blacklist.enabled":"TRUE","spark.driver.host":"172.22.0.111","spark.blacklist.task.maxTaskAttemptsPerNode":"3","spark.eventLog.enabled":"TRUE","spark.driver.port":"64511","spark.repl.class.uri":"spark://172.22.0.111:64511/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-f09ef9e2-7f15-433f-a5d1-30138d8764ca/repl-28d60911-dbc3-465f-b7b3-ee55c071595e","spark.app.name":"Spark shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"3","spark.scheduler.mode":"FIFO","spark.eventLog.overwrite":"TRUE","spark.blacklist.stage.maxFailedTasksPerExecutor":"3","spark.executor.id":"driver","spark.blacklist.application.maxFailedExecutorsPerNode":"2","spark.submit.deployMode":"client","spark.master":"local-cluster[4,4,1024]","spark.home":"/Users/Jose/IdeaProjects/spark","spark.eventLog.dir":"/Users/jose/logs","spark.sql.catalogImplementation":"in-memory","spark.eventLog.compress":"FALSE","spark.blacklist.application.maxFailedTasksPerExecutor":"1","spark.blacklist.timeout":"10000","spark.app.id":"app-20161115172038-0000","spark.task.maxFailures":"4"},"System Properties":{"java.io.tmpdir":"/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle 
Corporation","java.vm.specification.version":"1.8","user.home":"/Users/Jose","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib","user.dir":"/Users/Jose/IdeaProjects/spark","java.library.path":"/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.92-b14","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_92-b14","java.vm.info":"mixed mode","java.ext.dirs":"/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","io.netty.maxDirectMemory":"0","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.11.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"jose","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=10000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i 
/Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","java.version":"1.8.0_92","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/conf/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/jars/*":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20161115172038-0000","Timestamp":1479252037079,"User":"jose"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479252042589,"Executor ID":"2","Executor Info":{"Host":"172.22.0.111","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stdout","stderr":"http://172.22.0.111:64519/logPage/?appId=app-20161115172038-0000&executorId=2&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479252042593,"Executor ID":"0","Executor Info":{"Host":"172.22.0.111","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stdout","stderr":"http://172.22.0.111:64517/logPage/?appId=app-20161115172038-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479252042629,"Executor ID":"1","Executor 
Info":{"Host":"172.22.0.111","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stdout","stderr":"http://172.22.0.111:64518/logPage/?appId=app-20161115172038-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.111","Port":64540},"Maximum Memory":384093388,"Timestamp":1479252042687} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.111","Port":64539},"Maximum Memory":384093388,"Timestamp":1479252042689} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.111","Port":64541},"Maximum Memory":384093388,"Timestamp":1479252042692} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479252042711,"Executor ID":"3","Executor Info":{"Host":"172.22.0.111","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stdout","stderr":"http://172.22.0.111:64521/logPage/?appId=app-20161115172038-0000&executorId=3&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.111","Port":64543},"Maximum Memory":384093388,"Timestamp":1479252042759} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1479252043855,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0],"Properties":{}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD 
ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},"Properties":{}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479252044021,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1479252044052,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1479252044052,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1479252044053,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1479252044054,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1479252044055,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1479252044055,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt 
ID":0,"Task Info":{"Task ID":7,"Index":7,"Attempt":0,"Launch Time":1479252044056,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":8,"Attempt":0,"Launch Time":1479252044056,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":9,"Attempt":0,"Launch Time":1479252044057,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":10,"Attempt":0,"Launch Time":1479252044058,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":11,"Attempt":0,"Launch Time":1479252044058,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":12,"Index":12,"Attempt":0,"Launch Time":1479252044059,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":13,"Attempt":0,"Launch Time":1479252044060,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":14,"Attempt":0,"Launch Time":1479252044064,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":15,"Attempt":0,"Launch Time":1479252044065,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":12,"Attempt":0,"Launch Time":1479252044059,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044653,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":499,"Value":499,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":52390000,"Value":52390000,"Internal":true,"Count Failed 
Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":18,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":7909000,"Value":7909000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":1123,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":21,"Value":21,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":499,"Executor Deserialize CPU Time":52390000,"Executor Run Time":18,"Executor CPU Time":7909000,"Result Size":1123,"JVM GC Time":21,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1479252044054,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044657,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":508,"Value":1007,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":36827000,"Value":89217000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":18,"Value":36,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":4333000,"Value":12242000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":2246,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":21,"Value":42,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":508,"Executor Deserialize CPU Time":36827000,"Executor Run Time":18,"Executor CPU Time":4333000,"Result Size":1123,"JVM GC Time":21,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":8,"Attempt":0,"Launch Time":1479252044056,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1479252044658,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":509,"Value":1516,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":44100000,"Value":133317000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":17,"Value":53,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":11340000,"Value":23582000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":3369,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":21,"Value":63,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":509,"Executor Deserialize CPU Time":44100000,"Executor Run Time":17,"Executor CPU Time":11340000,"Result Size":1123,"JVM GC Time":21,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479252044021,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044692,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":511,"Value":2027,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":227762000,"Value":361079000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":69,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":3631000,"Value":27213000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1938,"Value":5307,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":21,"Value":84,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":2,"Value":5,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory 
Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":511,"Executor Deserialize CPU Time":227762000,"Executor Run Time":16,"Executor CPU Time":3631000,"Result Size":1938,"JVM GC Time":21,"Result Serialization Time":2,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat 
scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":495,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1479252044055,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044720,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":495,"Value":564,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Value":114,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":495,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":30,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File 
Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":494,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1479252044052,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044727,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":494,"Value":1058,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Value":144,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":494,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":30,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records 
Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":494,"Internal":true,"Count Failed 
Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":13,"Index":13,"Attempt":0,"Launch Time":1479252044060,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044729,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":494,"Value":1552,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Value":174,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":494,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":30,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":13,"Attempt":1,"Launch Time":1479252044731,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":1,"Attempt":1,"Launch Time":1479252044731,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":5,"Attempt":1,"Launch Time":1479252044732,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File 
Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":11,"Index":11,"Attempt":0,"Launch Time":1479252044058,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044736,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Value":2003,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Value":206,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":451,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":32,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records 
Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":11,"Attempt":1,"Launch Time":1479252044736,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":446,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":15,"Index":15,"Attempt":0,"Launch Time":1479252044065,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044737,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":446,"Value":2449,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Value":238,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":446,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":32,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":15,"Attempt":1,"Launch Time":1479252044737,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line 
Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":448,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":7,"Index":7,"Attempt":0,"Launch Time":1479252044056,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044741,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":448,"Value":2897,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Value":270,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":448,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":32,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":7,"Attempt":1,"Launch Time":1479252044742,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1479252044752,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":1,"Attempt":1,"Launch Time":1479252044731,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044748,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":8,"Value":2035,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3655000,"Value":364734000,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":899000,"Value":28112000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":884,"Value":6191,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":8,"Executor Deserialize CPU Time":3655000,"Executor Run Time":0,"Executor CPU Time":899000,"Result Size":884,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line 
Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":19,"Index":11,"Attempt":1,"Launch Time":1479252044736,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044749,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":2899,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":2,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":11,"Attempt":2,"Launch Time":1479252044749,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":7,"Attempt":1,"Launch Time":1479252044742,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044752,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":2038,"Internal":true,"Count Failed 
Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3566000,"Value":368300000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":2900,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1004000,"Value":29116000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":7154,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3566000,"Executor Run Time":1,"Executor CPU Time":1004000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat 
$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":20,"Index":15,"Attempt":1,"Launch Time":1479252044737,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044756,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":10,"Value":2910,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":10,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":15,"Attempt":2,"Launch Time":1479252044756,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":11,"Attempt":2,"Launch Time":1479252044749,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044759,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":2042,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3720000,"Value":372020000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":2911,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1009000,"Value":30125000,"Internal":true,"Count Failed 
Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":8117,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3720000,"Executor Run Time":1,"Executor CPU Time":1009000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":5,"Attempt":1,"Launch Time":1479252044732,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044760,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":2047,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4303000,"Value":376323000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":2913,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":999000,"Value":31124000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":9080,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":4303000,"Executor Run Time":2,"Executor CPU Time":999000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":15,"Attempt":2,"Launch Time":1479252044756,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044768,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":6,"Value":2053,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":4946000,"Value":381269000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":2914,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1176000,"Value":32300000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":10043,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":6,"Executor Deserialize CPU Time":4946000,"Executor 
Run Time":1,"Executor CPU Time":1176000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":13,"Attempt":1,"Launch Time":1479252044731,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044775,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":7,"Value":2060,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3406000,"Value":384675000,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1007000,"Value":33307000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":971,"Value":11014,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":6,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":7,"Executor Deserialize CPU Time":3406000,"Executor Run Time":0,"Executor CPU Time":1007000,"Result Size":971,"JVM GC Time":0,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring 
Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":456,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1479252044053,"Executor ID":"2","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044778,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":456,"Value":3370,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":32,"Value":302,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk 
Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":456,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":32,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":24,"Index":3,"Attempt":1,"Launch Time":1479252044778,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File 
Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":503,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":9,"Index":9,"Attempt":0,"Launch Time":1479252044057,"Executor ID":"0","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044789,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":503,"Value":3873,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":30,"Value":332,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use 
Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":503,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":30,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":25,"Index":9,"Attempt":1,"Launch Time":1479252044789,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":24,"Index":3,"Attempt":1,"Launch Time":1479252044778,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044791,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":2065,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2950000,"Value":387625000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3875,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":822000,"Value":34129000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":11977,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":2950000,"Executor Run Time":2,"Executor CPU 
Time":822000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":25,"Index":9,"Attempt":1,"Launch Time":1479252044789,"Executor ID":"1","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044798,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":2068,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":2604000,"Value":390229000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3876,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":845000,"Value":34974000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":12940,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":2604000,"Executor Run Time":1,"Executor CPU Time":845000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1479252044055,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044920,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":784,"Value":2852,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":56180000,"Value":446409000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":24,"Value":3900,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6046000,"Value":41020000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":13976,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":350,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":784,"Executor Deserialize CPU Time":56180000,"Executor Run Time":24,"Executor CPU Time":6046000,"Result Size":1036,"JVM GC Time":18,"Result Serialization Time":0,"Memory Bytes 
Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1479252044052,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044921,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":789,"Value":3641,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":34766000,"Value":481175000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":22,"Value":3922,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":8189000,"Value":49209000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":15012,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":368,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":789,"Executor Deserialize CPU Time":34766000,"Executor Run Time":22,"Executor CPU Time":8189000,"Result Size":1036,"JVM GC Time":18,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":14,"Index":14,"Attempt":0,"Launch Time":1479252044064,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044921,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":777,"Value":4418,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":29960000,"Value":511135000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":24,"Value":3946,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9708000,"Value":58917000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":16048,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":386,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":777,"Executor Deserialize CPU Time":29960000,"Executor Run Time":24,"Executor CPU Time":9708000,"Result Size":1036,"JVM GC Time":18,"Result Serialization 
Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":10,"Index":10,"Attempt":0,"Launch Time":1479252044058,"Executor ID":"3","Host":"172.22.0.111","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479252044924,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":791,"Value":5209,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":266560000,"Value":777695000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":3962,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":5884000,"Value":64801000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1851,"Value":17899,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":404,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":791,"Executor Deserialize CPU Time":266560000,"Executor Run Time":16,"Executor CPU Time":5884000,"Result Size":1851,"JVM GC Time":18,"Result Serialization 
Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1479252044017,"Completion Time":1479252044926,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Value":3962,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":404,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":17899,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block 
ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":777695000,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":64801000,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":6,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":5209,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1479252044931,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"2","taskFailures":4} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479252044930,"executorId":"0","taskFailures":4} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479252044930,"hostId":"172.22.0.111","executorFailures":2} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"2"} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorUnblacklisted","time":1479252055635,"executorId":"0"} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeUnblacklisted","time":1479252055635,"hostId":"172.22.0.111"} +{"Event":"SparkListenerApplicationEnd","Timestamp":1479252138874} diff --git a/core/src/test/resources/spark-events/app-20161116163331-0000 b/core/src/test/resources/spark-events/app-20161116163331-0000 new file mode 100755 index 0000000000000..57cfc5b973129 --- /dev/null +++ b/core/src/test/resources/spark-events/app-20161116163331-0000 @@ -0,0 +1,68 @@ +{"Event":"SparkListenerLogStart","Spark Version":"2.1.0-SNAPSHOT"} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"driver","Host":"172.22.0.167","Port":51475},"Maximum Memory":908381388,"Timestamp":1479335611477,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerEnvironmentUpdate","JVM Information":{"Java Home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","Java Version":"1.8.0_92 (Oracle Corporation)","Scala Version":"version 2.11.8"},"Spark Properties":{"spark.blacklist.task.maxTaskAttemptsPerExecutor":"3","spark.blacklist.enabled":"TRUE","spark.driver.host":"172.22.0.167","spark.blacklist.task.maxTaskAttemptsPerNode":"3","spark.eventLog.enabled":"TRUE","spark.driver.port":"51459","spark.repl.class.uri":"spark://172.22.0.167:51459/classes","spark.jars":"","spark.repl.class.outputDir":"/private/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/spark-1cbc97d0-7fe6-4c9f-8c2c-f6fe51ee3cf2/repl-39929169-ac4c-4c6d-b116-f648e4dd62ed","spark.app.name":"Spark 
shell","spark.blacklist.stage.maxFailedExecutorsPerNode":"3","spark.scheduler.mode":"FIFO","spark.eventLog.overwrite":"TRUE","spark.blacklist.stage.maxFailedTasksPerExecutor":"3","spark.executor.id":"driver","spark.blacklist.application.maxFailedExecutorsPerNode":"2","spark.submit.deployMode":"client","spark.master":"local-cluster[4,4,1024]","spark.home":"/Users/Jose/IdeaProjects/spark","spark.eventLog.dir":"/Users/jose/logs","spark.sql.catalogImplementation":"in-memory","spark.eventLog.compress":"FALSE","spark.blacklist.application.maxFailedTasksPerExecutor":"1","spark.blacklist.timeout":"1000000","spark.app.id":"app-20161116163331-0000","spark.task.maxFailures":"4"},"System Properties":{"java.io.tmpdir":"/var/folders/l4/d46wlzj16593f3d812vk49tw0000gp/T/","line.separator":"\n","path.separator":":","sun.management.compiler":"HotSpot 64-Bit Tiered Compilers","SPARK_SUBMIT":"true","sun.cpu.endian":"little","java.specification.version":"1.8","java.vm.specification.name":"Java Virtual Machine Specification","java.vendor":"Oracle Corporation","java.vm.specification.version":"1.8","user.home":"/Users/Jose","file.encoding.pkg":"sun.io","sun.nio.ch.bugLevel":"","ftp.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","sun.arch.data.model":"64","sun.boot.library.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib","user.dir":"/Users/Jose/IdeaProjects/spark","java.library.path":"/Users/Jose/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.","sun.cpu.isalist":"","os.arch":"x86_64","java.vm.version":"25.92-b14","java.endorsed.dirs":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/endorsed","java.runtime.version":"1.8.0_92-b14","java.vm.info":"mixed mode","java.ext.dirs":"/Users/Jose/Library/Java/Extensions:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/ext:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java","java.runtime.name":"Java(TM) SE Runtime Environment","file.separator":"/","io.netty.maxDirectMemory":"0","java.class.version":"52.0","scala.usejavacp":"true","java.specification.name":"Java Platform API Specification","sun.boot.class.path":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/resources.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/rt.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/sunrsasign.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jsse.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jce.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/charsets.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/lib/jfr.jar:/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre/classes","file.encoding":"UTF-8","user.timezone":"America/Chicago","java.specification.vendor":"Oracle Corporation","sun.java.launcher":"SUN_STANDARD","os.version":"10.11.6","sun.os.patch.level":"unknown","gopherProxySet":"false","java.vm.specification.vendor":"Oracle 
Corporation","user.country":"US","sun.jnu.encoding":"UTF-8","http.nonProxyHosts":"local|*.local|169.254/16|*.169.254/16","user.language":"en","socksNonProxyHosts":"local|*.local|169.254/16|*.169.254/16","java.vendor.url":"http://java.oracle.com/","java.awt.printerjob":"sun.lwawt.macosx.CPrinterJob","java.awt.graphicsenv":"sun.awt.CGraphicsEnvironment","awt.toolkit":"sun.lwawt.macosx.LWCToolkit","os.name":"Mac OS X","java.vm.vendor":"Oracle Corporation","java.vendor.url.bug":"http://bugreport.sun.com/bugreport/","user.name":"jose","java.vm.name":"Java HotSpot(TM) 64-Bit Server VM","sun.java.command":"org.apache.spark.deploy.SparkSubmit --master local-cluster[4,4,1024] --conf spark.blacklist.enabled=TRUE --conf spark.blacklist.timeout=1000000 --conf spark.blacklist.application.maxFailedTasksPerExecutor=1 --conf spark.eventLog.overwrite=TRUE --conf spark.blacklist.task.maxTaskAttemptsPerNode=3 --conf spark.blacklist.stage.maxFailedTasksPerExecutor=3 --conf spark.blacklist.task.maxTaskAttemptsPerExecutor=3 --conf spark.eventLog.compress=FALSE --conf spark.blacklist.stage.maxFailedExecutorsPerNode=3 --conf spark.eventLog.enabled=TRUE --conf spark.eventLog.dir=/Users/jose/logs --conf spark.blacklist.application.maxFailedExecutorsPerNode=2 --conf spark.task.maxFailures=4 --class org.apache.spark.repl.Main --name Spark shell spark-shell -i /Users/Jose/dev/jose-utils/blacklist/test-blacklist.scala","java.home":"/Library/Java/JavaVirtualMachines/jdk1.8.0_92.jdk/Contents/Home/jre","java.version":"1.8.0_92","sun.io.unicode.encoding":"UnicodeBig"},"Classpath Entries":{"/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-mapred-1.7.7-hadoop2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-core-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlet-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-column-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/snappy-java-1.1.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/oro-2.0.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/arpack_combined_all-0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-schema-1.2.15.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-assembly_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javassist-3.18.1-GA.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-tags_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-launcher_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math3-3.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-api-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-xml_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/objenesis-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire-macros_2.11-0.7.4.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-reflect-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib-local_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-mllib_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-server-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-mapper-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-scala_2.11-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-framework-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-client-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-asl-1.9.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-common/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/zookeeper-3.4.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-auth-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/repl/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jul-to-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-media-jaxb-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-io-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/RoaringBitmap-0.5.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.ws.rs-api-2.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/catalyst/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-unsafe_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-repl_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-continuation-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive-thriftserver/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-annotations-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-graphite-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-api-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-core-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/streaming/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-net-3.1.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-proxy-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-catalyst_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/lz4-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-crypto-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.annotation-api-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sql_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guava-14.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.servlet-api-3.1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-collections-3.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/conf/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/unused-1.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-encoding-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/tags/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-jackson_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-cli-1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-server-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/cglib-2.2.1-v20090111.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pyrolite-4.13.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-library-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-parser-combinators_2.11-1.0.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-6.1.26.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/py4j-0.10.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-configuration-1.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/core-1.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/core/target/jars/*":"System Classpath","/Users/Jose/IdeaProjects/spark/common/network-shuffle/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-format-2.3.0-incubating.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/kryo-shaded-3.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/core/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill-java-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-annotations-2.6.5.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-hadoop-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/sql/hive/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xz-1.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-jackson-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/aopalliance-repackaged-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-common-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/log4j-1.2.17.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-core-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-util-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalap-2.11.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/osgi-resource-locator-1.0.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-1.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compress-1.4.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jcl-over-slf4j-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/yarn/target/scala-2.11/classes":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-plus-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/protobuf-java-2.5.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/unsafe/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-module-paranamer-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/leveldbjni-all-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-core-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-api-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/compress-lzf-1.0.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/stream-2.7.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-shuffle-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-codec-1.10.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-yarn-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/common/sketch/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-core_2.11-2.1.0-SNAPSHOT.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-container-servlet-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-shuffle_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang-2.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/ivy-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-common-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-math-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-hdfs-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scala-compiler-2.11.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-jvm-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-lang3-3.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jsr305-1.3.9.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/minlog-1.3.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-3.8.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-webapp-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-ast_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xbean-asm5-shaded-4.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-io-2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/slf4j-log4j12-1.7.16.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-locator-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/shapeless_2.11-2.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-network-common_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-xml-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-httpclient-3.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/javax.inject-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/mllib/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/scalatest_2.11-2.2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hk2-utils-2.4.0-b34.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-client-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-guava-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-jndi-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/graphx/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-app-2.2.0.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/examples/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/xmlenc-0.52.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jets3t-0.7.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/curator-recipes-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/opencsv-2.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jtransforms-2.4.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/antlr4-runtime-4.5.3.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/chill_2.11-0.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-digester-1.8.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/univocity-parsers-2.2.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jline-2.12.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-streaming_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/launcher/target/scala-2.11/classes/":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/breeze-macros_2.11-0.12.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jersey-client-2.22.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jackson-databind-2.6.5.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-servlets-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/paranamer-2.6.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-security-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-ipc-1.7.7-tests.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/avro-1.7.7.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spire_2.11-0.7.4.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-client-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/metrics-json-3.1.2.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-beanutils-core-1.8.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/validation-api-1.1.0.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-graphx_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/netty-all-4.0.41.Final.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/janino-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/json4s-core_2.11-3.2.11.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/commons-compiler-3.0.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/guice-3.0.jar":"System 
Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-server-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/jetty-http-9.2.16.v20160414.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/parquet-common-1.8.1.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/hadoop-mapreduce-client-jobclient-2.2.0.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/spark-sketch_2.11-2.1.0-SNAPSHOT.jar":"System Classpath","/Users/Jose/IdeaProjects/spark/assembly/target/scala-2.11/jars/pmml-model-1.2.15.jar":"System Classpath"}} +{"Event":"SparkListenerApplicationStart","App Name":"Spark shell","App ID":"app-20161116163331-0000","Timestamp":1479335609916,"User":"jose"} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615320,"Executor ID":"3","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout","stderr":"http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"3","Host":"172.22.0.167","Port":51485},"Maximum Memory":908381388,"Timestamp":1479335615387,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615393,"Executor ID":"2","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout","stderr":"http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"}}} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615443,"Executor ID":"1","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout","stderr":"http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"2","Host":"172.22.0.167","Port":51487},"Maximum Memory":908381388,"Timestamp":1479335615448,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerExecutorAdded","Timestamp":1479335615462,"Executor ID":"0","Executor Info":{"Host":"172.22.0.167","Total Cores":4,"Log Urls":{"stdout":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout","stderr":"http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"}}} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"1","Host":"172.22.0.167","Port":51490},"Maximum Memory":908381388,"Timestamp":1479335615496,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerBlockManagerAdded","Block Manager ID":{"Executor ID":"0","Host":"172.22.0.167","Port":51491},"Maximum Memory":908381388,"Timestamp":1479335615515,"Maximum Onheap Memory":384093388,"Maximum Offheap Memory":524288000} +{"Event":"SparkListenerJobStart","Job ID":0,"Submission Time":1479335616467,"Stage Infos":[{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at 
:26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]}],"Stage IDs":[0],"Properties":{}} +{"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Accumulables":[]},"Properties":{}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479335616657,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task 
Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1479335616687,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1479335616688,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1479335616688,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1479335616689,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":5,"Index":5,"Attempt":0,"Launch Time":1479335616690,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1479335616691,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":7,"Index":7,"Attempt":0,"Launch Time":1479335616692,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":8,"Index":8,"Attempt":0,"Launch Time":1479335616692,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":9,"Index":9,"Attempt":0,"Launch Time":1479335616693,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":10,"Index":10,"Attempt":0,"Launch Time":1479335616694,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":11,"Index":11,"Attempt":0,"Launch Time":1479335616694,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task 
ID":12,"Index":12,"Attempt":0,"Launch Time":1479335616695,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":13,"Index":13,"Attempt":0,"Launch Time":1479335616696,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":14,"Index":14,"Attempt":0,"Launch Time":1479335616696,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":15,"Index":15,"Attempt":0,"Launch Time":1479335616697,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":0,"Failed":false,"Killed":false,"Accumulables":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":11,"Index":11,"Attempt":0,"Launch Time":1479335616694,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617253,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":465,"Value":465,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":47305000,"Value":47305000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":22,"Value":22,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":7220000,"Value":7220000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":1123,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":18,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":1,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":465,"Executor Deserialize CPU Time":47305000,"Executor Run Time":22,"Executor CPU Time":7220000,"Result Size":1123,"JVM GC Time":18,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":15,"Index":15,"Attempt":0,"Launch Time":1479335616697,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1479335617257,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":464,"Value":929,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":20082000,"Value":67387000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":21,"Value":43,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9084000,"Value":16304000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":2246,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":36,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":2,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":464,"Executor Deserialize CPU Time":20082000,"Executor Run Time":21,"Executor CPU Time":9084000,"Result Size":1123,"JVM GC Time":18,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":7,"Index":7,"Attempt":0,"Launch Time":1479335616692,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617257,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":468,"Value":1397,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":29183000,"Value":96570000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":21,"Value":64,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":5753000,"Value":22057000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1123,"Value":3369,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":54,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":3,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":468,"Executor Deserialize CPU Time":29183000,"Executor Run Time":21,"Executor CPU Time":5753000,"Result Size":1123,"JVM GC Time":18,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task 
Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":3,"Index":3,"Attempt":0,"Launch Time":1479335616688,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617257,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":470,"Value":1867,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":233387000,"Value":329957000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":22,"Value":86,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6783000,"Value":28840000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1938,"Value":5307,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":18,"Value":72,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Update":1,"Value":4,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":470,"Executor Deserialize CPU Time":233387000,"Executor Run Time":22,"Executor CPU Time":6783000,"Result Size":1938,"JVM GC Time":18,"Result Serialization Time":1,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line 
Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":453,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task 
ID":5,"Index":5,"Attempt":0,"Launch Time":1479335616690,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617319,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":453,"Value":539,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Value":94,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":453,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":22,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line 
Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":444,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":14,"Index":14,"Attempt":0,"Launch Time":1479335616696,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617326,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":444,"Value":983,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Value":123,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":444,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":29,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local 
Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":1,"Index":1,"Attempt":0,"Launch Time":1479335616687,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617327,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Value":1434,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Value":145,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":451,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":22,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring 
Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":13,"Index":13,"Attempt":0,"Launch Time":1479335616696,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617328,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":451,"Value":1885,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Value":167,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":451,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":22,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method 
Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":450,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":9,"Index":9,"Attempt":0,"Launch Time":1479335616693,"Executor ID":"2","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617329,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":450,"Value":2335,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":22,"Value":189,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor 
Deserialize CPU Time":0,"Executor Run Time":450,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":22,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat 
org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":444,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":10,"Index":10,"Attempt":0,"Launch Time":1479335616694,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617329,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":444,"Value":2779,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Value":218,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":444,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":29,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line 
Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":442,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":2,"Index":2,"Attempt":0,"Launch Time":1479335616688,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617329,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":442,"Value":3221,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Value":247,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":442,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":29,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":16,"Index":2,"Attempt":1,"Launch Time":1479335617332,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1479335617371,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":14,"Value":1903,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5136000,"Value":346556000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3673,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":958000,"Value":32856000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":9159,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":17,"Index":10,"Attempt":1,"Launch Time":1479335617333,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617370,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":10,"Value":1889,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3808000,"Value":341420000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3672,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1005000,"Value":31898000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":8196,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":18,"Index":9,"Attempt":1,"Launch Time":1479335617333,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617369,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":7,"Value":1879,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3737000,"Value":337612000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3670,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1066000,"Value":30893000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":7233,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":19,"Index":13,"Attempt":1,"Launch Time":1479335617334,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617368,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":1872,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3918000,"Value":333875000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3668,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":987000,"Value":29827000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":6270,"Internal":true,"Count 
Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"ExceptionFailure","Class Name":"java.lang.RuntimeException","Description":"bad exec","Stack Trace":[{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply$mcII$sp","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1","Method Name":"apply","File Name":"","Line Number":26},{"Declaring Class":"scala.collection.Iterator$$anon$11","Method Name":"next","File Name":"Iterator.scala","Line Number":409},{"Declaring Class":"org.apache.spark.util.Utils$","Method Name":"getIteratorSize","File Name":"Utils.scala","Line Number":1757},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.rdd.RDD$$anonfun$count$1","Method Name":"apply","File Name":"RDD.scala","Line Number":1135},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.SparkContext$$anonfun$runJob$5","Method Name":"apply","File Name":"SparkContext.scala","Line Number":1927},{"Declaring Class":"org.apache.spark.scheduler.ResultTask","Method Name":"runTask","File Name":"ResultTask.scala","Line Number":87},{"Declaring Class":"org.apache.spark.scheduler.Task","Method Name":"run","File Name":"Task.scala","Line Number":99},{"Declaring Class":"org.apache.spark.executor.Executor$TaskRunner","Method Name":"run","File Name":"Executor.scala","Line Number":282},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor","Method Name":"runWorker","File Name":"ThreadPoolExecutor.java","Line Number":1142},{"Declaring Class":"java.util.concurrent.ThreadPoolExecutor$Worker","Method Name":"run","File Name":"ThreadPoolExecutor.java","Line Number":617},{"Declaring Class":"java.lang.Thread","Method Name":"run","File Name":"Thread.java","Line Number":745}],"Full Stack Trace":"java.lang.RuntimeException: bad exec\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply$mcII$sp(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat $line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$anonfun$1.apply(:26)\n\tat scala.collection.Iterator$$anon$11.next(Iterator.scala:409)\n\tat org.apache.spark.util.Utils$.getIteratorSize(Utils.scala:1757)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.rdd.RDD$$anonfun$count$1.apply(RDD.scala:1135)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1927)\n\tat org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)\n\tat org.apache.spark.scheduler.Task.run(Task.scala:99)\n\tat org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)\n\tat java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)\n\tat java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)\n\tat java.lang.Thread.run(Thread.java:745)\n","Accumulator Updates":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":446,"Internal":true,"Count Failed 
Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":0,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Info":{"Task ID":6,"Index":6,"Attempt":0,"Launch Time":1479335616691,"Executor ID":"0","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617336,"Failed":true,"Killed":false,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Update":446,"Value":3667,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":29,"Value":276,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":0,"Executor Deserialize CPU Time":0,"Executor Run Time":446,"Executor CPU Time":0,"Result Size":0,"JVM GC Time":29,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} 
+{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":20,"Index":6,"Attempt":1,"Launch Time":1479335617349,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617371,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1907,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3503000,"Value":350059000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3674,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1042000,"Value":33898000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":10122,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":19,"Index":13,"Attempt":1,"Launch Time":1479335617334,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617368,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":5,"Value":1872,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3918000,"Value":333875000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3668,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":987000,"Value":29827000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":6270,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":5,"Executor Deserialize CPU Time":3918000,"Executor Run Time":1,"Executor CPU Time":987000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":21,"Index":1,"Attempt":1,"Launch Time":1479335617368,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617379,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1911,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3579000,"Value":353638000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3675,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":996000,"Value":34894000,"Internal":true,"Count Failed 
Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":11085,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":18,"Index":9,"Attempt":1,"Launch Time":1479335617333,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617369,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":7,"Value":1879,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3737000,"Value":337612000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3670,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1066000,"Value":30893000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":7233,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":7,"Executor Deserialize CPU Time":3737000,"Executor Run Time":2,"Executor CPU Time":1066000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":22,"Index":14,"Attempt":1,"Launch Time":1479335617369,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617380,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1915,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3412000,"Value":357050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3676,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1014000,"Value":35908000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":12048,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":17,"Index":10,"Attempt":1,"Launch Time":1479335617333,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617370,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":10,"Value":1889,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3808000,"Value":341420000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3672,"Internal":true,"Count Failed 
Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1005000,"Value":31898000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":8196,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":10,"Executor Deserialize CPU Time":3808000,"Executor Run Time":2,"Executor CPU Time":1005000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskStart","Stage ID":0,"Stage Attempt ID":0,"Task Info":{"Task ID":23,"Index":5,"Attempt":1,"Launch Time":1479335617370,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617380,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":1918,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3482000,"Value":360532000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3678,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1142000,"Value":37050000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":13011,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":16,"Index":2,"Attempt":1,"Launch Time":1479335617332,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617371,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":14,"Value":1903,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":5136000,"Value":346556000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3673,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":958000,"Value":32856000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":9159,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":14,"Executor Deserialize CPU Time":5136000,"Executor Run Time":1,"Executor CPU Time":958000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} 
+{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":20,"Index":6,"Attempt":1,"Launch Time":1479335617349,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617371,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1907,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3503000,"Value":350059000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3674,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1042000,"Value":33898000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":10122,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3503000,"Executor Run Time":1,"Executor CPU Time":1042000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":21,"Index":1,"Attempt":1,"Launch Time":1479335617368,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617379,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1911,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3579000,"Value":353638000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3675,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":996000,"Value":34894000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":11085,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3579000,"Executor Run Time":1,"Executor CPU Time":996000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":22,"Index":14,"Attempt":1,"Launch Time":1479335617369,"Executor 
ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617380,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":4,"Value":1915,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3412000,"Value":357050000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":1,"Value":3676,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1014000,"Value":35908000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":12048,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":4,"Executor Deserialize CPU Time":3412000,"Executor Run Time":1,"Executor CPU Time":1014000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":23,"Index":5,"Attempt":1,"Launch Time":1479335617370,"Executor ID":"3","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617380,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":3,"Value":1918,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":3482000,"Value":360532000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":2,"Value":3678,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":1142000,"Value":37050000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":963,"Value":13011,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":3,"Executor Deserialize CPU Time":3482000,"Executor Run Time":2,"Executor CPU Time":1142000,"Result Size":963,"JVM GC Time":0,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":0,"Index":0,"Attempt":0,"Launch Time":1479335616657,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish 
Time":1479335617470,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":714,"Value":2632,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":41300000,"Value":401832000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":18,"Value":3696,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":6640000,"Value":43690000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":14047,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":17,"Value":293,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":714,"Executor Deserialize CPU Time":41300000,"Executor Run Time":18,"Executor CPU Time":6640000,"Result Size":1036,"JVM GC Time":17,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":8,"Index":8,"Attempt":0,"Launch Time":1479335616692,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617471,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":714,"Value":3346,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":43682000,"Value":445514000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":15,"Value":3711,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":9441000,"Value":53131000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":15083,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":17,"Value":310,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":714,"Executor Deserialize CPU Time":43682000,"Executor Run Time":15,"Executor CPU Time":9441000,"Result Size":1036,"JVM GC Time":17,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":12,"Index":12,"Attempt":0,"Launch Time":1479335616695,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting 
Result Time":0,"Finish Time":1479335617471,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":691,"Value":4037,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":54811000,"Value":500325000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":3727,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":4571000,"Value":57702000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1036,"Value":16119,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":17,"Value":327,"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":691,"Executor Deserialize CPU Time":54811000,"Executor Run Time":16,"Executor CPU Time":4571000,"Result Size":1036,"JVM GC Time":17,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[]}} +{"Event":"SparkListenerTaskEnd","Stage ID":0,"Stage Attempt ID":0,"Task Type":"ResultTask","Task End Reason":{"Reason":"Success"},"Task Info":{"Task ID":4,"Index":4,"Attempt":0,"Launch Time":1479335616689,"Executor ID":"1","Host":"172.22.0.167","Locality":"PROCESS_LOCAL","Speculative":false,"Getting Result Time":0,"Finish Time":1479335617473,"Failed":false,"Killed":false,"Accumulables":[{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Update":716,"Value":4753,"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Update":220235000,"Value":720560000,"Internal":true,"Count Failed Values":true},{"ID":2,"Name":"internal.metrics.executorRunTime","Update":16,"Value":3743,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Update":5849000,"Value":63551000,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Update":1851,"Value":17970,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Update":17,"Value":344,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Update":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block 
ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true}]},"Task Metrics":{"Executor Deserialize Time":716,"Executor Deserialize CPU Time":220235000,"Executor Run Time":16,"Executor CPU Time":5849000,"Result Size":1851,"JVM GC Time":17,"Result Serialization Time":0,"Memory Bytes Spilled":0,"Disk Bytes Spilled":0,"Shuffle Read Metrics":{"Remote Blocks Fetched":0,"Local Blocks Fetched":0,"Fetch Wait Time":0,"Remote Bytes Read":0,"Local Bytes Read":0,"Total Records Read":0},"Shuffle Write Metrics":{"Shuffle Bytes Written":0,"Shuffle Write Time":0,"Shuffle Records Written":0},"Input Metrics":{"Bytes Read":0,"Records Read":0},"Output Metrics":{"Bytes Written":0,"Records Written":0},"Updated Blocks":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}]}} +{"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":0,"Stage Attempt ID":0,"Stage Name":"count at :26","Number of Tasks":16,"RDD Info":[{"RDD ID":1,"Name":"MapPartitionsRDD","Scope":"{\"id\":\"1\",\"name\":\"map\"}","Callsite":"map at :26","Parent IDs":[0],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0},{"RDD ID":0,"Name":"ParallelCollectionRDD","Scope":"{\"id\":\"0\",\"name\":\"parallelize\"}","Callsite":"parallelize at :26","Parent IDs":[],"Storage Level":{"Use Disk":false,"Use Memory":false,"Deserialized":false,"Replication":1},"Number of Partitions":16,"Number of Cached Partitions":0,"Memory Size":0,"Disk Size":0}],"Parent IDs":[],"Details":"org.apache.spark.rdd.RDD.count(RDD.scala:1135)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:26)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw.(:31)\n$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw.(:33)\n$line16.$read$$iw$$iw$$iw$$iw$$iw.(:35)\n$line16.$read$$iw$$iw$$iw$$iw.(:37)\n$line16.$read$$iw$$iw$$iw.(:39)\n$line16.$read$$iw$$iw.(:41)\n$line16.$read$$iw.(:43)\n$line16.$read.(:45)\n$line16.$read$.(:49)\n$line16.$read$.()\n$line16.$eval$.$print$lzycompute(:7)\n$line16.$eval$.$print(:6)\n$line16.$eval.$print()\nsun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)\nsun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)\nsun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)\njava.lang.reflect.Method.invoke(Method.java:498)\nscala.tools.nsc.interpreter.IMain$ReadEvalPrint.call(IMain.scala:786)","Submission Time":1479335616653,"Completion 
Time":1479335617476,"Accumulables":[{"ID":2,"Name":"internal.metrics.executorRunTime","Value":3743,"Internal":true,"Count Failed Values":true},{"ID":5,"Name":"internal.metrics.jvmGCTime","Value":344,"Internal":true,"Count Failed Values":true},{"ID":4,"Name":"internal.metrics.resultSize","Value":17970,"Internal":true,"Count Failed Values":true},{"ID":10,"Name":"internal.metrics.updatedBlockStatuses","Value":[{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}},{"Block ID":"broadcast_0_piece0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":false,"Replication":1},"Memory Size":1150,"Disk Size":0}},{"Block ID":"broadcast_0","Status":{"Storage Level":{"Use Disk":false,"Use Memory":true,"Deserialized":true,"Replication":1},"Memory Size":1736,"Disk Size":0}}],"Internal":true,"Count Failed Values":true},{"ID":1,"Name":"internal.metrics.executorDeserializeCpuTime","Value":720560000,"Internal":true,"Count Failed Values":true},{"ID":3,"Name":"internal.metrics.executorCpuTime","Value":63551000,"Internal":true,"Count Failed Values":true},{"ID":6,"Name":"internal.metrics.resultSerializationTime","Value":4,"Internal":true,"Count Failed Values":true},{"ID":0,"Name":"internal.metrics.executorDeserializeTime","Value":4753,"Internal":true,"Count Failed Values":true}]}} +{"Event":"SparkListenerJobEnd","Job ID":0,"Completion Time":1479335617480,"Job Result":{"Result":"JobSucceeded"}} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"2","taskFailures":4} +{"Event":"org.apache.spark.scheduler.SparkListenerExecutorBlacklisted","time":1479335617478,"executorId":"0","taskFailures":4} +{"Event":"org.apache.spark.scheduler.SparkListenerNodeBlacklisted","time":1479335617478,"hostId":"172.22.0.167","executorFailures":2} +{"Event":"SparkListenerApplicationEnd","Timestamp":1479335620587} diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 6d03ee091e4ed..ddbcb2d19dcbb 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -243,7 +243,7 @@ private[spark] object AccumulatorSuite { import InternalAccumulator._ /** - * Create a long accumulator and register it to [[AccumulatorContext]]. + * Create a long accumulator and register it to `AccumulatorContext`. 
*/ def createLongAccum( name: String, @@ -258,7 +258,7 @@ private[spark] object AccumulatorSuite { } /** - * Make an [[AccumulableInfo]] out of an [[Accumulable]] with the intent to use the + * Make an `AccumulableInfo` out of an [[Accumulable]] with the intent to use the * info as an accumulator update. */ def makeInfo(a: AccumulatorV2[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46f..48408ccc8f81b 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -112,7 +114,7 @@ trait RDDCheckpointTester { self: SparkFunSuite => * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, * the generated RDD will remember the partitions and therefore potentially the whole lineage. * This function should be called only those RDD whose partitions refer to parent RDD's - * partitions (i.e., do not call it on simple RDD like MappedRDD). + * partitions (i.e., do not call it on simple RDDs). * * @param op an operation to run on the RDD * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints @@ -386,7 +388,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS // the parent RDD has been checkpointed and parent partitions have been changed. // Note that this test is very specific to the current implementation of CartesianRDD. val ones = sc.makeRDD(1 to 100, 10).map(x => x) - checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) val cartesian = new CartesianRDD(sc, ones, ones) val splitBeforeCheckpoint = serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition]) @@ -409,7 +411,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS // Note that this test is very specific to the current implementation of // CoalescedRDDPartitions. 
val ones = sc.makeRDD(1 to 100, 10).map(x => x) - checkpoint(ones, reliableCheckpoint) // checkpoint that MappedRDD + checkpoint(ones, reliableCheckpoint) val coalesced = new CoalescedRDD(ones, 2) val splitBeforeCheckpoint = serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition]) @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala index fb8d701ebda8a..91355f7362900 100644 --- a/core/src/test/scala/org/apache/spark/DebugFilesystem.scala +++ b/core/src/test/scala/org/apache/spark/DebugFilesystem.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io.{FileDescriptor, InputStream} import java.lang import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -31,20 +30,29 @@ import org.apache.spark.internal.Logging object DebugFilesystem extends Logging { // Stores the set of active streams and their creation sites. 
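(Aside, not part of the patch.) In the DebugFilesystem hunk that follows, the ConcurrentHashMap is swapped for a plain mutable.Map whose accessors are all wrapped in `synchronized`, and each open stream's creation site is captured as a Throwable so the leak error can carry it as a cause. A minimal standalone sketch of that pattern, with hypothetical names:

```scala
// Sketch of the locking pattern adopted below (hypothetical names, not code from this patch).
import scala.collection.mutable

object OpenResourceTracker {
  // Maps each open resource to the stack trace captured when it was opened.
  private val open = mutable.Map.empty[AnyRef, Throwable]

  def add(resource: AnyRef): Unit = open.synchronized {
    open.put(resource, new Throwable())
  }

  def remove(resource: AnyRef): Unit = open.synchronized {
    open.remove(resource)
  }

  def assertNoneOpen(): Unit = open.synchronized {
    if (open.nonEmpty) {
      // Report the first leaked creation site as the cause, mirroring assertNoOpenStreams below.
      throw new IllegalStateException(s"${open.size} possibly leaked resources", open.values.head)
    }
  }
}
```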
- private val openStreams = new ConcurrentHashMap[FSDataInputStream, Throwable]() + private val openStreams = mutable.Map.empty[FSDataInputStream, Throwable] - def clearOpenStreams(): Unit = { + def addOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.put(stream, new Throwable()) + } + + def clearOpenStreams(): Unit = openStreams.synchronized { openStreams.clear() } - def assertNoOpenStreams(): Unit = { - val numOpen = openStreams.size() + def removeOpenStream(stream: FSDataInputStream): Unit = openStreams.synchronized { + openStreams.remove(stream) + } + + def assertNoOpenStreams(): Unit = openStreams.synchronized { + val numOpen = openStreams.values.size if (numOpen > 0) { - for (exc <- openStreams.values().asScala) { + for (exc <- openStreams.values) { logWarning("Leaked filesystem connection created at:") exc.printStackTrace() } - throw new RuntimeException(s"There are $numOpen possibly leaked file streams.") + throw new IllegalStateException(s"There are $numOpen possibly leaked file streams.", + openStreams.values.head) } } } @@ -59,8 +67,7 @@ class DebugFilesystem extends LocalFileSystem { override def open(f: Path, bufferSize: Int): FSDataInputStream = { val wrapped: FSDataInputStream = super.open(f, bufferSize) - openStreams.put(wrapped, new Throwable()) - + addOpenStream(wrapped) new FSDataInputStream(wrapped.getWrappedStream) { override def setDropBehind(dropBehind: lang.Boolean): Unit = wrapped.setDropBehind(dropBehind) @@ -97,7 +104,7 @@ class DebugFilesystem extends LocalFileSystem { override def close(): Unit = { wrapped.close() - openStreams.remove(wrapped) + removeOpenStream(wrapped) } override def read(): Int = wrapped.read() diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 4e36adc8baf3f..84f7f1fc8eb09 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.io.ChunkedByteBuffer @@ -28,7 +29,8 @@ class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} -class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext { +class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext + with EncryptionFunSuite { val clusterUrl = "local-cluster[2,1,1024]" @@ -149,8 +151,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex sc.parallelize(1 to 10).count() } - private def testCaching(storageLevel: StorageLevel): Unit = { - sc = new SparkContext(clusterUrl, "test") + private def testCaching(conf: SparkConf, storageLevel: StorageLevel): Unit = { + sc = new SparkContext(conf.setMaster(clusterUrl).setAppName("test")) sc.jobProgressListener.waitUntilExecutorsUp(2, 30000) val data = sc.parallelize(1 to 1000, 10) val cachedData = data.persist(storageLevel) @@ -187,8 +189,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2, "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2 ).foreach { case (testName, storageLevel) => - test(testName) { - 
testCaching(storageLevel) + encryptionTest(testName) { conf => + testCaching(conf, storageLevel) } } diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index ec409712b953c..4ea42fc7d5c22 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -1138,7 +1138,10 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def requestExecutors(numAdditionalExecutors: Int): Boolean = sc.requestExecutors(numAdditionalExecutors) - override def killExecutors(executorIds: Seq[String]): Seq[String] = { + override def killExecutors( + executorIds: Seq[String], + replace: Boolean, + force: Boolean): Seq[String] = { val response = sc.killExecutors(executorIds) if (response) { executorIds @@ -1154,4 +1157,8 @@ private class DummyLocalSchedulerBackend (sc: SparkContext, sb: SchedulerBackend override def reviveOffers(): Unit = sb.reviveOffers() override def defaultParallelism(): Int = sb.defaultParallelism() + + override def killExecutorsOnHost(host: String): Boolean = { + false + } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index eb3fb99747d12..fe944031bc948 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalSh /** * This suite creates an external shuffle server and routes all shuffle fetches through it. * Note that failures in this suite may arise due to changes in Spark that invalidate expectations - * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how + * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders. 
*/ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index cc52bb1d23cd5..5be0121db58ae 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark import java.io._ +import java.nio.ByteBuffer import java.util.zip.GZIPOutputStream import scala.io.Source +import org.apache.hadoop.fs.Path import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} @@ -29,7 +31,6 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel @@ -58,10 +59,15 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { nums.saveAsTextFile(outputDir) // Read the plain text file and check it's OK val outputFile = new File(outputDir, "part-00000") - val content = Source.fromFile(outputFile).mkString - assert(content === "1\n2\n3\n4\n") - // Also try reading it in as a text file RDD - assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + val bufferSrc = Source.fromFile(outputFile) + Utils.tryWithSafeFinally { + val content = bufferSrc.mkString + assert(content === "1\n2\n3\n4\n") + // Also try reading it in as a text file RDD + assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4")) + } { + bufferSrc.close() + } } test("text files (compressed)") { @@ -231,184 +237,82 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)")) } - test("binary file input as byte array") { - sc = new SparkContext("local", "test") + private def writeBinaryData(testOutput: Array[Byte], testOutputCopies: Int): File = { val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) + val file = new FileOutputStream(outFile) val channel = file.getChannel - channel.write(bbuf) + for (i <- 0 until testOutputCopies) { + // Shift values by i so that they're different in the output + val alteredOutput = testOutput.map(b => (b + i).toByte) + channel.write(ByteBuffer.wrap(alteredOutput)) + } channel.close() file.close() + outFile + } - val inRdd = sc.binaryFiles(outFileName) - val (infile: String, indata: PortableDataStream) = inRdd.collect.head - + test("binary file input as byte array") { + sc = new SparkContext("local", "test") + val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) + val (infile, indata) = inRdd.collect().head // Make sure the name and array match - assert(infile.contains(outFileName)) // a prefix may get added + assert(infile.contains(outFile.toURI.getPath)) // a prefix may get added 
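(Aside, not part of the patch.) The binaryFiles API these tests now exercise through the shared writeBinaryData helper returns (path, PortableDataStream) pairs, and the bytes are only read when toArray() is called. A minimal usage sketch, with a made-up file path:

```scala
// Usage sketch only; the input path is hypothetical and not taken from this patch.
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("binaryFiles-example"))
// Each element is (file path, PortableDataStream); the stream contents are read lazily.
val files = sc.binaryFiles("/tmp/record-bytestream-00000.bin")
val (path, stream) = files.collect().head
val bytes: Array[Byte] = stream.toArray() // forces the file to be read
sc.stop()
```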
assert(indata.toArray === testOutput) } test("portabledatastream caching tests") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).cache() - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).cache() + inRdd.foreach(_._2.toArray()) // force the file to read // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream persist disk storage") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY) - inRdd.foreach{ - curData: (String, PortableDataStream) => - curData._2.toArray() // force the file to read - } - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head - - // Try reading the output back as an object file - - assert(indata.toArray === testOutput) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath).persist(StorageLevel.DISK_ONLY) + inRdd.foreach(_._2.toArray()) // force the file to read + assert(inRdd.values.collect().head.toArray === testOutput) } test("portabledatastream flatmap tests") { sc = new SparkContext("local", "test") - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) + val outFile = writeBinaryData(testOutput, 1) + val inRdd = sc.binaryFiles(outFile.getAbsolutePath) val numOfCopies = 3 - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - channel.write(bbuf) - channel.close() - file.close() - - val inRdd = sc.binaryFiles(outFileName) - val mappedRdd = inRdd.map { - curData: (String, PortableDataStream) => - (curData._2.getPath(), curData._2) - } - val copyRdd = mappedRdd.flatMap { - curData: (String, PortableDataStream) => - for (i <- 1 to numOfCopies) yield (i, curData._2) - } - - val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect() - - // Try reading the output back as an object file + val copyRdd = inRdd.flatMap(curData => (0 until numOfCopies).map(_ => curData._2)) + val copyArr = copyRdd.collect() assert(copyArr.length == 
numOfCopies) - copyArr.foreach{ - cEntry: (Int, PortableDataStream) => - assert(cEntry._2.toArray === testOutput) + for (i <- copyArr.indices) { + assert(copyArr(i).toArray === testOutput) } - } test("fixed record length binary file as byte array") { - // a fixed length of 6 bytes - sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, testOutput.length) - // make sure there are enough elements + val outFile = writeBinaryData(testOutput, testOutputCopies) + val inRdd = sc.binaryRecords(outFile.getAbsolutePath, testOutput.length) assert(inRdd.count == testOutputCopies) - - // now just compare the first one - val indata: Array[Byte] = inRdd.collect.head - assert(indata === testOutput) + val inArr = inRdd.collect() + for (i <- inArr.indices) { + assert(inArr(i) === testOutput.map(b => (b + i).toByte)) + } } test ("negative binary record length should raise an exception") { - // a fixed length of 6 bytes sc = new SparkContext("local", "test") - - val outFile = new File(tempDir, "record-bytestream-00000.bin") - val outFileName = outFile.getAbsolutePath() - - // create file - val testOutput = Array[Byte](1, 2, 3, 4, 5, 6) - val testOutputCopies = 10 - - // write data to file - val file = new java.io.FileOutputStream(outFile) - val channel = file.getChannel - for(i <- 1 to testOutputCopies) { - val bbuf = java.nio.ByteBuffer.wrap(testOutput) - channel.write(bbuf) - } - channel.close() - file.close() - - val inRdd = sc.binaryRecords(outFileName, -1) - + val outFile = writeBinaryData(Array[Byte](1, 2, 3, 4, 5, 6), 1) intercept[SparkException] { - inRdd.count + sc.binaryRecords(outFile.getAbsolutePath, -1).count() } } @@ -497,7 +401,7 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.set("mapred.output.format.class", classOf[TextOutputFormat[String, String]].getName) - job.set("mapred.output.dir", tempDir.getPath + "/outputDataset_old") + job.set("mapreduce.output.fileoutputformat.outputdir", tempDir.getPath + "/outputDataset_old") randomRDD.saveAsHadoopDataset(job) assert(new File(tempDir.getPath + "/outputDataset_old/part-00000").exists() === true) } @@ -511,7 +415,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) val jobConfig = job.getConfiguration - jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") + jobConfig.set("mapreduce.output.fileoutputformat.outputdir", + tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) } @@ -527,7 +432,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { .mapPartitionsWithInputSplit { (split, part) => Iterator(split.asInstanceOf[FileSplit].getPath.toUri.getPath) }.collect() - assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + val outPathOne = new Path(outDir, 
"part-00000").toUri.getPath + val outPathTwo = new Path(outDir, "part-00001").toUri.getPath + assert(inputPaths.toSet === Set(outPathOne, outPathTwo)) } test("Get input files via new Hadoop API") { @@ -541,7 +448,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { .mapPartitionsWithInputSplit { (split, part) => Iterator(split.asInstanceOf[NewFileSplit].getPath.toUri.getPath) }.collect() - assert(inputPaths.toSet === Set(s"$outDir/part-00000", s"$outDir/part-00001")) + val outPathOne = new Path(outDir, "part-00000").toUri.getPath + val outPathTwo = new Path(outDir, "part-00001").toUri.getPath + assert(inputPaths.toSet === Set(outPathOne, outPathTwo)) } test("spark.files.ignoreCorruptFiles should work both HadoopRDD and NewHadoopRDD") { diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 915d7a1b8b164..88916488c0def 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -46,8 +46,8 @@ class HeartbeatReceiverSuite with PrivateMethodTester with LocalSparkContext { - private val executorId1 = "executor-1" - private val executorId2 = "executor-2" + private val executorId1 = "1" + private val executorId2 = "2" // Shared state that must be reset before and after each test private var scheduler: TaskSchedulerImpl = null @@ -93,12 +93,12 @@ class HeartbeatReceiverSuite test("task scheduler is set correctly") { assert(heartbeatReceiver.scheduler === null) - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) assert(heartbeatReceiver.scheduler !== null) } test("normal heartbeat") { - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) @@ -116,14 +116,14 @@ class HeartbeatReceiverSuite } test("reregister if heartbeat from unregistered executor") { - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) // Received heartbeat from unknown executor, so we ask it to re-register triggerHeartbeat(executorId1, executorShouldReregister = true) assert(getTrackedExecutors.isEmpty) } test("reregister if heartbeat from removed executor") { - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) // Remove the second executor but not the first @@ -140,7 +140,7 @@ class HeartbeatReceiverSuite test("expire dead hosts") { val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) @@ -149,7 +149,7 @@ class HeartbeatReceiverSuite heartbeatReceiverClock.advance(executorTimeout / 2) triggerHeartbeat(executorId1, executorShouldReregister = false) heartbeatReceiverClock.advance(executorTimeout) - heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + heartbeatReceiverRef.askSync[Boolean](ExpireDeadHosts) // Only the second executor should be expired as a dead host 
verify(scheduler).executorLost(Matchers.eq(executorId2), any()) val trackedExecutors = getTrackedExecutors @@ -173,11 +173,11 @@ class HeartbeatReceiverSuite val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) - fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( + fakeSchedulerBackend.driverEndpoint.askSync[Boolean]( RegisterExecutor(executorId1, dummyExecutorEndpointRef1, "1.2.3.4", 0, Map.empty)) - fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( + fakeSchedulerBackend.driverEndpoint.askSync[Boolean]( RegisterExecutor(executorId2, dummyExecutorEndpointRef2, "1.2.3.5", 0, Map.empty)) - heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) + heartbeatReceiverRef.askSync[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) addExecutorAndVerify(executorId2) triggerHeartbeat(executorId1, executorShouldReregister = false) @@ -195,7 +195,7 @@ class HeartbeatReceiverSuite // Here we use a timeout of O(seconds), but in practice this whole test takes O(10ms). val executorTimeout = heartbeatReceiver.invokePrivate(_executorTimeoutMs()) heartbeatReceiverClock.advance(executorTimeout * 2) - heartbeatReceiverRef.askWithRetry[Boolean](ExpireDeadHosts) + heartbeatReceiverRef.askSync[Boolean](ExpireDeadHosts) val killThread = heartbeatReceiver.invokePrivate(_killExecutorThread()) killThread.shutdown() // needed for awaitTermination killThread.awaitTermination(10L, TimeUnit.SECONDS) @@ -213,7 +213,7 @@ class HeartbeatReceiverSuite executorShouldReregister: Boolean): Unit = { val metrics = TaskMetrics.empty val blockManagerId = BlockManagerId(executorId, "localhost", 12345) - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + val response = heartbeatReceiverRef.askSync[HeartbeatResponse]( Heartbeat(executorId, Array(1L -> metrics.accumulators()), blockManagerId)) if (executorShouldReregister) { assert(response.reregisterBlockManager) @@ -272,7 +272,7 @@ private class FakeSchedulerBackend( protected override def doRequestTotalExecutors(requestedTotal: Int): Future[Boolean] = { clusterManagerEndpoint.ask[Boolean]( - RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount)) + RequestExecutors(requestedTotal, localityAwareTasks, hostToLocalTaskCount, Set.empty[String])) } protected override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = { @@ -291,7 +291,7 @@ private class FakeClusterManager(override val rpcEnv: RpcEnv) extends RpcEndpoin def getExecutorIdsToKill: Set[String] = executorIdsToKill.toSet override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestExecutors(requestedTotal, _, _) => + case RequestExecutors(requestedTotal, _, _, _) => targetNumExecutors = requestedTotal context.reply(true) case KillExecutors(executorIds) => diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 840f55ce2f6e5..8d7be77f51fe9 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics diff --git 
a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index a3490fc79e458..99150a1430d95 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -209,6 +209,83 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft assert(jobB.get() === 100) } + test("task reaper kills JVM if killed tasks keep running for too long") { + val conf = new SparkConf() + .set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "5s") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 10000, 2).map { i => + while (true) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. + assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + + test("task reaper will not kill JVM if spark.task.killTimeout == -1") { + val conf = new SparkConf() + .set("spark.task.reaper.enabled", "true") + .set("spark.task.reaper.killTimeout", "-1") + .set("spark.task.reaper.PollingInterval", "1s") + .set("spark.deploy.maxExecutorRetries", "1") + sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + + // Add a listener to release the semaphore once any tasks are launched. + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart) { + sem.release() + } + }) + + // jobA is the one to be cancelled. + val jobA = Future { + sc.setJobGroup("jobA", "this is a job to be cancelled", interruptOnCancel = true) + sc.parallelize(1 to 2, 2).map { i => + val startTime = System.currentTimeMillis() + while (System.currentTimeMillis() < startTime + 10000) { } + }.count() + } + + // Block until both tasks of job A have started and cancel job A. + sem.acquire(2) + // Small delay to ensure tasks actually start executing the task body + Thread.sleep(1000) + + sc.clearJobGroup() + val jobB = sc.parallelize(1 to 100, 2).countAsync() + sc.cancelJobGroup("jobA") + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 15.seconds) }.getCause + assert(e.getMessage contains "cancel") + + // Once A is cancelled, job B should finish fairly quickly. 
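(Aside, not part of the patch.) The task-reaper knobs used by these two tests are ordinary SparkConf settings; outside of a test, enabling the reaper would look roughly like the sketch below. The timeout values are illustrative, and the polling key is documented as spark.task.reaper.pollingInterval.

```scala
// Illustrative configuration sketch; the timeout values are examples, not from this patch.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.task.reaper.enabled", "true")        // monitor tasks that keep running after being killed
  .set("spark.task.reaper.killTimeout", "60s")     // kill the executor JVM if a killed task outlives this
  .set("spark.task.reaper.pollingInterval", "10s") // how often the reaper re-checks killed tasks
```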
+ assert(ThreadUtils.awaitResult(jobB, 60.seconds) === 100) + } + test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // twoJobsSharingStageSemaphore: diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 24ec99c7e5e60..1dd89bcbe36bc 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite -/** Manages a local `sc` {@link SparkContext} variable, correctly stopping it after each test. */ +/** Manages a local `sc` `SparkContext` variable, correctly stopping it after each test. */ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 2b8b1805bc83f..6fc7cea6ee94a 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -103,6 +103,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.ui.enabled", "false") + conf.set("spark.ssl.ui.port", "4242") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") conf.set("spark.ssl.ui.keyStorePassword", "12345") @@ -118,6 +119,7 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val opts = SSLOptions.parse(conf, "spark.ssl.ui", defaults = Some(defaultOpts)) assert(opts.enabled === false) + assert(opts.port === Some(4242)) assert(opts.trustStore.isDefined === true) assert(opts.trustStore.get.getName === "truststore") assert(opts.trustStore.get.getAbsolutePath === trustStorePath) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index a854f5bb9b7ce..58b865969f517 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark -import java.util.Properties +import java.util.{Locale, Properties} import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers @@ -29,7 +29,7 @@ import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListene import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} -import org.apache.spark.util.MutablePair +import org.apache.spark.util.{MutablePair, Utils} abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -239,7 +239,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.toLowerCase.contains("serializable")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("serializable")) } test("shuffle with different compression settings (SPARK-3426)") { diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index 83906cff123bf..0897891ee1758 100644 --- 
a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -303,6 +303,25 @@ class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSyst } } + test("encryption requires authentication") { + val conf = new SparkConf() + conf.validateSettings() + + conf.set(NETWORK_ENCRYPTION_ENABLED, true) + intercept[IllegalArgumentException] { + conf.validateSettings() + } + + conf.set(NETWORK_ENCRYPTION_ENABLED, false) + conf.set(SASL_ENCRYPTION_ENABLED, true) + intercept[IllegalArgumentException] { + conf.validateSettings() + } + + conf.set(NETWORK_AUTH_ENABLED, true) + conf.validateSettings() + } + } class Class1 {} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index c451c596b069a..7e26139a2bead 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -18,22 +18,27 @@ package org.apache.spark import java.io.File -import java.net.MalformedURLException +import java.net.{MalformedURLException, URI} import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit +import scala.concurrent.duration._ import scala.concurrent.Await -import scala.concurrent.duration.Duration import com.google.common.io.Files +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.scalatest.concurrent.Eventually import org.scalatest.Matchers._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart} import org.apache.spark.util.Utils -class SparkContextSuite extends SparkFunSuite with LocalSparkContext { + +class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually { test("Only one SparkContext may be active at a time") { // Regression test for SPARK-4180 @@ -289,6 +294,22 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("add jar with invalid path") { + val tmpDir = Utils.createTempDir() + val tmpJar = File.createTempFile("test", ".jar", tmpDir) + + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + sc.addJar(tmpJar.getAbsolutePath) + + // Invaid jar path will only print the error log, will not add to file server. 
+ sc.addJar("dummy.jar") + sc.addJar("") + sc.addJar(tmpDir.getAbsolutePath) + + sc.listJars().size should be (1) + sc.listJars().head should include (tmpJar.getName) + } + test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") { try { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) @@ -451,4 +472,151 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("register and deregister Spark listener from SparkContext") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val sparkListener1 = new SparkListener { } + val sparkListener2 = new SparkListener { } + sc.addSparkListener(sparkListener1) + sc.addSparkListener(sparkListener2) + assert(sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + sc.removeSparkListener(sparkListener1) + assert(!sc.listenerBus.listeners.contains(sparkListener1)) + assert(sc.listenerBus.listeners.contains(sparkListener2)) + } + + test("Cancelling stages/jobs with custom reasons.") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val REASON = "You shall not pass" + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + if (SparkContextSuite.cancelStage) { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + sc.cancelStage(taskStart.stageId, REASON) + SparkContextSuite.cancelStage = false + } + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + if (SparkContextSuite.cancelJob) { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + sc.cancelJob(jobStart.jobId, REASON) + SparkContextSuite.cancelJob = false + } + } + } + sc.addSparkListener(listener) + + for (cancelWhat <- Seq("stage", "job")) { + SparkContextSuite.isTaskStarted = false + SparkContextSuite.cancelStage = (cancelWhat == "stage") + SparkContextSuite.cancelJob = (cancelWhat == "job") + + val ex = intercept[SparkException] { + sc.range(0, 10000L).mapPartitions { x => + org.apache.spark.SparkContextSuite.isTaskStarted = true + x + }.cartesian(sc.range(0, 10L))count() + } + + ex.getCause() match { + case null => + assert(ex.getMessage().contains(REASON)) + case cause: SparkException => + assert(cause.getMessage().contains(REASON)) + case cause: Throwable => + fail("Expected the cause to be SparkException, got " + cause.toString() + " instead.") + } + + eventually(timeout(20.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + } + + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has + // started, kill and re-schedule it. The second run of the task will complete immediately. + // If this test times out, then the first version of the task wasn't killed successfully. 
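(Aside, not part of the patch.) The helper defined next drives SparkContext.killTaskAttempt from a listener; in isolation the call looks like the sketch below, where `sc` is assumed to be an already-running SparkContext and the task id comes from the task-start event.

```scala
// Sketch of the kill-from-a-listener wiring; assumes an active SparkContext named `sc`.
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}

sc.addSparkListener(new SparkListener {
  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
    // Interrupt the running attempt and record a human-readable reason.
    sc.killTaskAttempt(taskStart.taskInfo.taskId, interruptThread = true, reason = "example kill")
  }
})
```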
+ def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + SparkContextSuite.isTaskStarted = false + SparkContextSuite.taskKilled = false + SparkContextSuite.taskSucceeded = false + + val listener = new SparkListener { + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.isTaskStarted) + } + if (!SparkContextSuite.taskKilled) { + SparkContextSuite.taskKilled = true + sc.killTaskAttempt(taskStart.taskInfo.taskId, true, "first attempt will hang") + } + } + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskInfo.attemptNumber == 1 && taskEnd.reason == Success) { + SparkContextSuite.taskSucceeded = true + } + } + } + sc.addSparkListener(listener) + eventually(timeout(20.seconds)) { + sc.parallelize(1 to 1).foreach { x => + // first attempt will hang + if (!SparkContextSuite.isTaskStarted) { + SparkContextSuite.isTaskStarted = true + blockFn + } + // second attempt succeeds immediately + } + } + eventually(timeout(10.seconds)) { + assert(SparkContextSuite.taskSucceeded) + } + } + + test("SPARK-19446: DebugFilesystem.assertNoOpenStreams should report " + + "open streams to help debugging") { + val fs = new DebugFilesystem() + fs.initialize(new URI("file:///"), new Configuration()) + val file = File.createTempFile("SPARK19446", "temp") + Files.write(Array.ofDim[Byte](1000), file) + val path = new Path("file:///" + file.getCanonicalPath) + val stream = fs.open(path) + val exc = intercept[RuntimeException] { + DebugFilesystem.assertNoOpenStreams() + } + assert(exc != null) + assert(exc.getCause() != null) + stream.close() + } +} + +object SparkContextSuite { + @volatile var cancelJob = false + @volatile var cancelStage = false + @volatile var isTaskStarted = false + @volatile var taskKilled = false + @volatile var taskSucceeded = false } diff --git a/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala new file mode 100644 index 0000000000000..6a979aefe6e90 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/r/JVMObjectTrackerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.r + +import org.apache.spark.SparkFunSuite + +class JVMObjectTrackerSuite extends SparkFunSuite { + test("JVMObjectId does not take null IDs") { + intercept[IllegalArgumentException] { + JVMObjectId(null) + } + } + + test("JVMObjectTracker") { + val tracker = new JVMObjectTracker + assert(tracker.size === 0) + withClue("an empty tracker can be cleared") { + tracker.clear() + } + val none = JVMObjectId("none") + assert(tracker.get(none) === None) + intercept[NoSuchElementException] { + tracker(JVMObjectId("none")) + } + + val obj1 = new Object + val id1 = tracker.addAndGetId(obj1) + assert(id1 != null) + assert(tracker.size === 1) + assert(tracker.get(id1).get.eq(obj1)) + assert(tracker(id1).eq(obj1)) + + val obj2 = new Object + val id2 = tracker.addAndGetId(obj2) + assert(id1 !== id2) + assert(tracker.size === 2) + assert(tracker(id2).eq(obj2)) + + val Some(obj1Removed) = tracker.remove(id1) + assert(obj1Removed.eq(obj1)) + assert(tracker.get(id1) === None) + assert(tracker.size === 1) + assert(tracker(id2).eq(obj2)) + + val obj3 = new Object + val id3 = tracker.addAndGetId(obj3) + assert(tracker.size === 2) + assert(id3 != id1) + assert(id3 != id2) + assert(tracker(id3).eq(obj3)) + + tracker.clear() + assert(tracker.size === 0) + assert(tracker.get(id1) === None) + assert(tracker.get(id2) === None) + assert(tracker.get(id3) === None) + } +} diff --git a/external/java8-tests/src/test/scala/test/org/apache/spark/java8/JDK8ScalaSuite.scala b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala similarity index 72% rename from external/java8-tests/src/test/scala/test/org/apache/spark/java8/JDK8ScalaSuite.scala rename to core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala index c4042e47e84e8..085cc267ca74d 100644 --- a/external/java8-tests/src/test/scala/test/org/apache/spark/java8/JDK8ScalaSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/r/RBackendSuite.scala @@ -15,16 +15,17 @@ * limitations under the License. */ -package test.org.apache.spark.java8 +package org.apache.spark.api.r -import org.apache.spark.SharedSparkContext import org.apache.spark.SparkFunSuite -/** - * Test cases where JDK8-compiled Scala user code is used with Spark. 
- */ -class JDK8ScalaSuite extends SparkFunSuite with SharedSparkContext { - test("basic RDD closure test (SPARK-6152)") { - sc.parallelize(1 to 1000).map(x => x * x).count() +class RBackendSuite extends SparkFunSuite { + test("close() clears jvmObjectTracker") { + val backend = new RBackend + val tracker = backend.jvmObjectTracker + val id = tracker.addAndGetId(new Object) + backend.close() + assert(tracker.get(id) === None) + assert(tracker.size === 0) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 973676398ae54..46f9ac6b0273a 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.broadcast +import java.util.Locale + import scala.util.Random import org.scalatest.Assertions @@ -24,8 +26,10 @@ import org.scalatest.Assertions import org.apache.spark._ import org.apache.spark.io.SnappyCompressionCodec import org.apache.spark.rdd.RDD +import org.apache.spark.security.EncryptionFunSuite import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage._ +import org.apache.spark.util.io.ChunkedByteBuffer // Dummy class that creates a broadcast variable but doesn't use it class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { @@ -43,7 +47,7 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { } } -class BroadcastSuite extends SparkFunSuite with LocalSparkContext { +class BroadcastSuite extends SparkFunSuite with LocalSparkContext with EncryptionFunSuite { test("Using TorrentBroadcast locally") { sc = new SparkContext("local", "test") @@ -61,9 +65,8 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } - test("Accessing TorrentBroadcast variables in a local cluster") { + encryptionTest("Accessing TorrentBroadcast variables in a local cluster") { conf => val numSlaves = 4 - val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -85,7 +88,9 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val size = 1 + rand.nextInt(1024 * 10) val data: Array[Byte] = new Array[Byte](size) rand.nextBytes(data) - val blocks = blockifyObject(data, blockSize, serializer, compressionCodec) + val blocks = blockifyObject(data, blockSize, serializer, compressionCodec).map { b => + new ChunkedByteBuffer(b).toInputStream(dispose = true) + } val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec) assert(unblockified === data) } @@ -127,7 +132,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { val thrown = intercept[IllegalStateException] { sc.broadcast(Seq(1, 2, 3)) } - assert(thrown.getMessage.toLowerCase.contains("stopped")) + assert(thrown.getMessage.toLowerCase(Locale.ROOT).contains("stopped")) } test("Forbid broadcasting RDD directly") { @@ -137,6 +142,17 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } + encryptionTest("Cache broadcast to disk") { conf => + conf.setMaster("local") + .setAppName("test") + .set("spark.memory.useLegacyMode", "true") + .set("spark.storage.memoryFraction", "0.0") + sc = new SparkContext(conf) + val list = List[Int](1, 2, 3, 4) + val 
broadcast = sc.broadcast(list) + assert(broadcast.value.sum === 10) + } + /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index c9b3d657c2b9d..f50cb38311db2 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -142,7 +142,7 @@ private[deploy] object IvyTestUtils { |} """.stripMargin val sourceFile = - new JavaSourceFromString(new File(dir, className).getAbsolutePath, contents) + new JavaSourceFromString(new File(dir, className).toURI.getPath, contents) createCompiledClass(className, dir, sourceFile, Seq.empty) } diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index 13cba94578a6a..005587051b6ad 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate -import org.apache.spark.util.ResetSystemProperties +import org.apache.spark.util.{ResetSystemProperties, Utils} class RPackageUtilsSuite extends SparkFunSuite @@ -74,9 +74,13 @@ class RPackageUtilsSuite val deps = Seq(dep1, dep2).mkString(",") IvyTestUtils.withRepository(main, Some(deps), None, withR = true) { repo => val jars = Seq(main, dep1, dep2).map(c => new JarFile(getJarPath(c, new File(new URI(repo))))) - assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") - assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") - assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + Utils.tryWithSafeFinally { + assert(RPackageUtils.checkManifestForR(jars(0)), "should have R code") + assert(!RPackageUtils.checkManifestForR(jars(1)), "should not have R code") + assert(!RPackageUtils.checkManifestForR(jars(2)), "should not have R code") + } { + jars.foreach(_.close()) + } } } @@ -131,7 +135,7 @@ class RPackageUtilsSuite test("SparkR zipping works properly") { val tempDir = Files.createTempDir() - try { + Utils.tryWithSafeFinally { IvyTestUtils.writeFile(tempDir, "test.R", "abc") val fakeSparkRDir = new File(tempDir, "SparkR") assert(fakeSparkRDir.mkdirs()) @@ -144,14 +148,19 @@ class RPackageUtilsSuite IvyTestUtils.writeFile(fakePackageDir, "DESCRIPTION", "abc") val finalZip = RPackageUtils.zipRLibraries(tempDir, "sparkr.zip") assert(finalZip.exists()) - val entries = new ZipFile(finalZip).entries().asScala.map(_.getName).toSeq - assert(entries.contains("/test.R")) - assert(entries.contains("/SparkR/abc.R")) - assert(entries.contains("/SparkR/DESCRIPTION")) - assert(!entries.contains("/package.zip")) - assert(entries.contains("/packageTest/def.R")) - assert(entries.contains("/packageTest/DESCRIPTION")) - } finally { + val zipFile = new ZipFile(finalZip) + Utils.tryWithSafeFinally { + val entries = zipFile.entries().asScala.map(_.getName).toSeq + assert(entries.contains("/test.R")) + assert(entries.contains("/SparkR/abc.R")) + assert(entries.contains("/SparkR/DESCRIPTION")) + assert(!entries.contains("/package.zip")) + assert(entries.contains("/packageTest/def.R")) + assert(entries.contains("/packageTest/DESCRIPTION")) + 
} { + zipFile.close() + } + } { FileUtils.deleteDirectory(tempDir) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 0000000000000..ab24a76e20a30 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + 
// If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c649e305a37e..a43839a8815f9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,8 +21,10 @@ import java.io._ import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,21 +36,12 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString -import org.apache.spark.util.{ResetSystemProperties, Utils} +import org.apache.spark.scheduler.EventLoggingListener +import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} -// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch -// of properties that needed to be cleared after tests. -class SparkSubmitSuite - extends SparkFunSuite - with Matchers - with BeforeAndAfterEach - with ResetSystemProperties - with Timeouts { - override def beforeEach() { - super.beforeEach() - System.setProperty("spark.testing", "true") - } +trait TestPrematureExit { + suite: SparkFunSuite => private val noOpOutputStream = new OutputStream { def write(b: Int) = {} @@ -65,16 +58,19 @@ class SparkSubmitSuite } /** Returns true if the script exits and the given search string is printed. */ - private def testPrematureExit(input: Array[String], searchString: String) = { + private[spark] def testPrematureExit( + input: Array[String], + searchString: String, + mainObject: CommandLineUtils = SparkSubmit) : Unit = { val printStream = new BufferPrintStream() - SparkSubmit.printStream = printStream + mainObject.printStream = printStream @volatile var exitedCleanly = false - SparkSubmit.exitFn = (_) => exitedCleanly = true + mainObject.exitFn = (_) => exitedCleanly = true val thread = new Thread { override def run() = try { - SparkSubmit.main(input) + mainObject.main(input) } catch { // If exceptions occur after the "exit" has happened, fine to ignore them. // These represent code paths not reachable during normal execution. @@ -88,6 +84,22 @@ class SparkSubmitSuite fail(s"Search string '$searchString' not found in $joined") } } +} + +// Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch +// of properties that needed to be cleared after tests. 
+class SparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts + with TestPrematureExit { + + override def beforeEach() { + super.beforeEach() + System.setProperty("spark.testing", "true") + } // scalastyle:off println test("prints usage on empty input") { @@ -139,6 +151,17 @@ class SparkSubmitSuite appArgs.childArgs should be (Seq("--master", "local", "some", "--weird", "args")) } + test("print the right queue name") { + val clArgs = Seq( + "--name", "myApp", + "--class", "Foo", + "--conf", "spark.yarn.queue=thequeue", + "userjar.jar") + val appArgs = new SparkSubmitArguments(clArgs) + appArgs.queue should be ("thequeue") + appArgs.toString should include ("thequeue") + } + test("specify deploy mode through configuration") { val clArgs = Seq( "--master", "yarn", @@ -204,7 +227,12 @@ class SparkSubmitSuite childArgsStr should include ("--arg arg1 --arg arg2") childArgsStr should include regex ("--jar .*thejar.jar") mainClass should be ("org.apache.spark.deploy.yarn.Client") - classpath should have length (0) + + // In yarn cluster mode, also adding jars to classpath + classpath(0) should endWith ("thejar.jar") + classpath(1) should endWith ("one.jar") + classpath(2) should endWith ("two.jar") + classpath(3) should endWith ("three.jar") sysProps("spark.executor.memory") should be ("5g") sysProps("spark.driver.memory") should be ("4g") @@ -379,6 +407,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) @@ -452,7 +511,7 @@ class SparkSubmitSuite val tempDir = Utils.createTempDir() val srcDir = new File(tempDir, "sparkrtest") srcDir.mkdirs() - val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").getAbsolutePath, + val excSource = new JavaSourceFromString(new File(srcDir, "DummyClass").toURI.getPath, """package sparkrtest; | |public class DummyClass implements java.io.Serializable { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 4877710c1237d..266c9d33b5a96 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.deploy import java.io.{File, OutputStream, PrintStream} +import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import com.google.common.io.Files import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings -import org.apache.ivy.plugins.resolver.{AbstractResolver, FileSystemResolver, IBiblioResolver} +import org.apache.ivy.plugins.resolver.{AbstractResolver, ChainResolver, FileSystemResolver, IBiblioResolver} import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite @@ -66,22 +68,25 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { test("create repo resolvers") { val settings = new IvySettings - val res1 = SparkSubmitUtils.createRepoResolvers(None, settings) + val res1 = SparkSubmitUtils.createRepoResolvers(settings.getDefaultIvyUserDir) // should have central and spark-packages by default assert(res1.getResolvers.size() === 4) assert(res1.getResolvers.get(0).asInstanceOf[IBiblioResolver].getName === "local-m2-cache") assert(res1.getResolvers.get(1).asInstanceOf[FileSystemResolver].getName === "local-ivy-cache") assert(res1.getResolvers.get(2).asInstanceOf[IBiblioResolver].getName === "central") assert(res1.getResolvers.get(3).asInstanceOf[IBiblioResolver].getName === "spark-packages") + } + test("create additional resolvers") { val repos = "a/1,b/2,c/3" - val resolver2 = SparkSubmitUtils.createRepoResolvers(Option(repos), settings) - assert(resolver2.getResolvers.size() === 7) + val settings = SparkSubmitUtils.buildIvySettings(Option(repos), None) + val resolver = settings.getDefaultResolver.asInstanceOf[ChainResolver] + assert(resolver.getResolvers.size() === 4) val expected = repos.split(",").map(r => s"$r/") - resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i < 3) { - assert(resolver.getName === s"repo-${i + 1}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) + resolver.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => + if (1 < i && i < 3) { + assert(resolver.getName === s"repo-$i") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 1)) } } } @@ -126,8 +131,10 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val main = MavenCoordinate("my.awesome.lib", "mylib", "0.1") IvyTestUtils.withRepository(main, None, None) { repo => // end to end - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, Option(repo), - Option(tempIvyPath), isTest = true) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + SparkSubmitUtils.buildIvySettings(Option(repo), Option(tempIvyPath)), + isTest = true) assert(jarPath.indexOf(tempIvyPath) >= 0, "should use non-default ivy path") } } @@ -137,7 +144,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val dep = "my.great.dep:mydep:0.5" // Local M2 repository IvyTestUtils.withRepository(main, Some(dep), Some(SparkSubmitUtils.m2Path)) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + SparkSubmitUtils.buildIvySettings(None, None), isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") 
assert(jarPath.indexOf("mydep") >= 0, "should find dependency") @@ -146,7 +155,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val settings = new IvySettings val ivyLocal = new File(settings.getDefaultIvyUserDir, "local" + File.separator) IvyTestUtils.withRepository(main, Some(dep), Some(ivyLocal), useIvyLayout = true) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, None, + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + SparkSubmitUtils.buildIvySettings(None, None), isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") assert(jarPath.indexOf("mydep") >= 0, "should find dependency") @@ -156,8 +167,10 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { settings.setDefaultIvyUserDir(new File(tempIvyPath)) IvyTestUtils.withRepository(main, Some(dep), Some(dummyIvyLocal), useIvyLayout = true, ivySettings = settings) { repo => - val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, None, - Some(tempIvyPath), isTest = true) + val jarPath = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + SparkSubmitUtils.buildIvySettings(None, Some(tempIvyPath)), + isTest = true) assert(jarPath.indexOf("mylib") >= 0, "should find artifact") assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") assert(jarPath.indexOf("mydep") >= 0, "should find dependency") @@ -166,7 +179,10 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { test("dependency not found throws RuntimeException") { intercept[RuntimeException] { - SparkSubmitUtils.resolveMavenCoordinates("a:b:c", None, None, isTest = true) + SparkSubmitUtils.resolveMavenCoordinates( + "a:b:c", + SparkSubmitUtils.buildIvySettings(None, None), + isTest = true) } } @@ -178,12 +194,17 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") + ",org.apache.spark:spark-core_fake:1.2.0" - val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, isTest = true) + val path = SparkSubmitUtils.resolveMavenCoordinates( + coordinates, + SparkSubmitUtils.buildIvySettings(None, None), + isTest = true) assert(path === "", "should return empty path") val main = MavenCoordinate("org.apache.spark", "spark-streaming-kafka-assembly_2.10", "1.2.0") IvyTestUtils.withRepository(main, None, None) { repo => - val files = SparkSubmitUtils.resolveMavenCoordinates(coordinates + "," + main.toString, - Some(repo), None, isTest = true) + val files = SparkSubmitUtils.resolveMavenCoordinates( + coordinates + "," + main.toString, + SparkSubmitUtils.buildIvySettings(Some(repo), None), + isTest = true) assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") } } @@ -192,10 +213,49 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") val dep = "my.great.dep:mydep:0.5" IvyTestUtils.withRepository(main, Some(dep), None) { repo => - val files = SparkSubmitUtils.resolveMavenCoordinates(main.toString, - Some(repo), None, Seq("my.great.dep:mydep"), isTest = true) + val files = SparkSubmitUtils.resolveMavenCoordinates( + main.toString, + SparkSubmitUtils.buildIvySettings(Some(repo), None), + Seq("my.great.dep:mydep"), + isTest = true) assert(files.indexOf(main.artifactId) >= 0, "Did not return artifact") assert(files.indexOf("my.great.dep") < 0, "Returned excluded artifact") 
} } + + test("load ivy settings file") { + val main = new MavenCoordinate("my.great.lib", "mylib", "0.1") + val dep = "my.great.dep:mydep:0.5" + val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator) + val settingsText = + s""" + | + | + | + | + | + | + | + | + | + | + |""".stripMargin + + val settingsFile = new File(tempIvyPath, "ivysettings.xml") + Files.write(settingsText, settingsFile, StandardCharsets.UTF_8) + val settings = SparkSubmitUtils.loadIvySettings(settingsFile.toString, None, None) + settings.setDefaultIvyUserDir(new File(tempIvyPath)) // NOTE - can't set this through file + + val testUtilSettings = new IvySettings + testUtilSettings.setDefaultIvyUserDir(new File(tempIvyPath)) + IvyTestUtils.withRepository(main, Some(dep), Some(dummyIvyLocal), useIvyLayout = true, + ivySettings = testUtilSettings) { repo => + val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, settings, isTest = true) + assert(jarPath.indexOf("mylib") >= 0, "should find artifact") + assert(jarPath.indexOf(tempIvyPath) >= 0, "should be in new ivy path") + assert(jarPath.indexOf("mydep") >= 0, "should find dependency") + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index e29eb8552e134..bf7480d79f8a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.deploy import scala.collection.mutable import scala.concurrent.duration._ -import org.mockito.Mockito.{mock, when} +import org.mockito.Matchers.any +import org.mockito.Mockito.{mock, verify, when} import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ @@ -29,10 +30,11 @@ import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMaste import org.apache.spark.deploy.master.ApplicationInfo import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.Worker +import org.apache.spark.internal.config import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster._ -import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RegisterExecutor +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RegisterExecutorFailed} /** * End-to-end tests for dynamic allocation in standalone mode. 
@@ -354,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice (SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -378,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -433,10 +437,11 @@ class StandaloneDynamicAllocationSuite assert(executors.size === 2) // simulate running a task on the executor - val getMap = PrivateMethod[mutable.HashMap[String, Int]]('executorIdToTaskCount) + val getMap = + PrivateMethod[mutable.HashMap[String, mutable.HashSet[Long]]]('executorIdToRunningTaskIds) val taskScheduler = sc.taskScheduler.asInstanceOf[TaskSchedulerImpl] - val executorIdToTaskCount = taskScheduler invokePrivate getMap() - executorIdToTaskCount(executors.head) = 1 + val executorIdToRunningTaskIds = taskScheduler invokePrivate getMap() + executorIdToRunningTaskIds(executors.head) = mutable.HashSet(1L) // kill the busy executor without force; this should fail assert(killExecutor(sc, executors.head, force = false).isEmpty) apps = getApplications() @@ -466,6 +471,52 @@ class StandaloneDynamicAllocationSuite } } + test("kill all executors on localhost") { + sc = new SparkContext(appConf) + val appId = sc.applicationId + eventually(timeout(10.seconds), interval(10.millis)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.executors.size === 2) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + val beforeList = getApplications().head.executors.keys.toSet + assert(killExecutorsOnHost(sc, "localhost").equals(true)) + + syncExecutors(sc) + val afterList = getApplications().head.executors.keys.toSet + + eventually(timeout(10.seconds), interval(100.millis)) { + assert(beforeList.intersect(afterList).size == 0) + } + } + + test("executor registration on a blacklisted host must fail") { + sc = new SparkContext(appConf.set(config.BLACKLIST_ENABLED.key, "true")) + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + val message = RegisterExecutor("one", endpointRef, "blacklisted-host", 10, Map.empty) + + // Get "localhost" on a blacklist. + val taskScheduler = mock(classOf[TaskSchedulerImpl]) + when(taskScheduler.nodeBlacklist()).thenReturn(Set("blacklisted-host")) + when(taskScheduler.sc).thenReturn(sc) + sc.taskScheduler = taskScheduler + + // Create a fresh scheduler backend to blacklist "localhost". 
+ sc.schedulerBackend.stop() + val backend = + new StandaloneSchedulerBackend(taskScheduler, sc, Array(masterRpcEnv.address.toSparkURL)) + backend.start() + + backend.driverEndpoint.ask[Boolean](message) + eventually(timeout(10.seconds), interval(100.millis)) { + verify(endpointRef).send(RegisterExecutorFailed(any())) + } + } + // =============================== // | Utility methods for testing | // =============================== @@ -498,7 +549,7 @@ class StandaloneDynamicAllocationSuite /** Get the Master state */ private def getMasterState: MasterStateResponse = { - master.self.askWithRetry[MasterStateResponse](RequestMasterState) + master.self.askSync[MasterStateResponse](RequestMasterState) } /** Get the applications that are active from Master */ @@ -527,6 +578,16 @@ class StandaloneDynamicAllocationSuite } } + /** Kill the executors on a given host. */ + private def killExecutorsOnHost(sc: SparkContext, host: String): Boolean = { + syncExecutors(sc) + sc.schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutorsOnHost(host) + case _ => fail("expected coarse grained scheduler") + } + } + /** * Return a list of executor IDs belonging to this application. * @@ -561,7 +622,7 @@ class StandaloneDynamicAllocationSuite when(endpointRef.address).thenReturn(mockAddress) val message = RegisterExecutor(id, endpointRef, "localhost", 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] - backend.driverEndpoint.askWithRetry[Boolean](message) + backend.driverEndpoint.askSync[Boolean](message) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index bc58fb2a362a4..936639b845789 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -171,7 +171,7 @@ class AppClientSuite /** Get the Master state */ private def getMasterState: MasterStateResponse = { - master.self.askWithRetry[MasterStateResponse](RequestMasterState) + master.self.askSync[MasterStateResponse](RequestMasterState) } /** Get the applications that are active from Master */ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index e3304be792af7..871c87415d35d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -177,7 +177,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar ended: Long): SparkUI = { val info = new ApplicationInfo(name, name, Some(1), Some(1), Some(1), Some(64), Seq(new AttemptInfo(attemptId, new Date(started), new Date(ended), - new Date(ended), ended - started, "user", completed))) + new Date(ended), ended - started, "user", completed, org.apache.spark.SPARK_VERSION))) val ui = mock[SparkUI] when(ui.getApplicationInfoList).thenReturn(List(info).iterator) when(ui.getAppName).thenReturn(name) @@ -253,7 +253,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar assertNotFound(appId, None) } - test("Test that if an attempt ID is is set, it must be used in lookups") { + test("Test that if an attempt ID is set, it must be used in lookups") { val operations = new StubCacheOperations() val clock = new ManualClock(1) implicit val 
cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index a5eda7b5a5a75..9b3e4ec793825 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -35,10 +36,11 @@ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ +import org.apache.spark.security.GroupMappingServiceProvider import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -46,7 +48,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc private var testDir: File = null before { - testDir = Utils.createTempDir() + testDir = Utils.createTempDir(namePrefix = s"a b%20c+d") } after { @@ -66,7 +68,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } test("Parse application logs") { - val provider = new FsHistoryProvider(createTestConf()) + val clock = new ManualClock(12345678) + val provider = new FsHistoryProvider(createTestConf(), clock) // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) @@ -106,15 +109,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc user: String, completed: Boolean): ApplicationHistoryInfo = { ApplicationHistoryInfo(id, name, - List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) + List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed, ""))) } + // For completed files, lastUpdated would be lastModified time. list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + + // For Inprogress files, lastUpdated would be current loading time. list(2) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, - newAppIncomplete.lastModified(), "test", false)) + clock.getTimeMillis(), "test", false)) // Make sure the UI can be rendered. list.foreach { case info => @@ -125,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. 
assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -140,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { @@ -299,6 +316,48 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc assert(!log2.exists()) } + test("log cleaner for inProgress files") { + val firstFileModifiedTime = TimeUnit.SECONDS.toMillis(10) + val secondFileModifiedTime = TimeUnit.SECONDS.toMillis(20) + val maxAge = TimeUnit.SECONDS.toMillis(40) + val clock = new ManualClock(0) + val provider = new FsHistoryProvider( + createTestConf().set("spark.history.fs.cleaner.maxAge", s"${maxAge}ms"), clock) + + val log1 = newLogFile("inProgressApp1", None, inProgress = true) + writeFile(log1, true, None, + SparkListenerApplicationStart( + "inProgressApp1", Some("inProgressApp1"), 3L, "test", Some("attempt1")) + ) + + clock.setTime(firstFileModifiedTime) + provider.checkForLogs() + + val log2 = newLogFile("inProgressApp2", None, inProgress = true) + writeFile(log2, true, None, + SparkListenerApplicationStart( + "inProgressApp2", Some("inProgressApp2"), 23L, "test2", Some("attempt2")) + ) + + clock.setTime(secondFileModifiedTime) + provider.checkForLogs() + + // This should not trigger any cleanup + updateAndCheck(provider)(list => list.size should be(2)) + + // Should trigger cleanup for first file but not second one + clock.setTime(firstFileModifiedTime + maxAge + 1) + updateAndCheck(provider)(list => list.size should be(1)) + assert(!log1.exists()) + assert(log2.exists()) + + // Should cleanup the second file as well. + clock.setTime(secondFileModifiedTime + maxAge + 1) + updateAndCheck(provider)(list => list.size should be(0)) + assert(!log1.exists()) + assert(!log2.exists()) + } + test("Event log copy") { val provider = new FsHistoryProvider(createTestConf()) val logs = (1 to 2).map { i => @@ -428,6 +487,102 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("support history server ui admin acls") { + def createAndCheck(conf: SparkConf, properties: (String, String)*) + (checkFn: SecurityManager => Unit): Unit = { + // Empty the testDir for each test. 
+ if (testDir.exists() && testDir.isDirectory) { + testDir.listFiles().foreach { f => if (f.isFile) f.delete() } + } + + var provider: FsHistoryProvider = null + try { + provider = new FsHistoryProvider(conf) + val log = newLogFile("app1", Some("attempt1"), inProgress = false) + writeFile(log, true, None, + SparkListenerApplicationStart("app1", Some("app1"), System.currentTimeMillis(), + "test", Some("attempt1")), + SparkListenerEnvironmentUpdate(Map( + "Spark Properties" -> properties.toSeq, + "JVM Information" -> Seq.empty, + "System Properties" -> Seq.empty, + "Classpath Entries" -> Seq.empty + )), + SparkListenerApplicationEnd(System.currentTimeMillis())) + + provider.checkForLogs() + val appUi = provider.getAppUI("app1", Some("attempt1")) + + assert(appUi.nonEmpty) + val securityManager = appUi.get.ui.securityManager + checkFn(securityManager) + } finally { + if (provider != null) { + provider.stop() + } + } + } + + // Test both history ui admin acls and application acls are configured. + val conf1 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.history.ui.admin.acls", "user1,user2") + .set("spark.history.ui.admin.acls.groups", "group1") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + + createAndCheck(conf1, ("spark.admin.acls", "user"), ("spark.admin.acls.groups", "group")) { + securityManager => + // Test whether user has permission to access UI. + securityManager.checkUIViewPermissions("user1") should be (true) + securityManager.checkUIViewPermissions("user2") should be (true) + securityManager.checkUIViewPermissions("user") should be (true) + securityManager.checkUIViewPermissions("abc") should be (false) + + // Test whether user with admin group has permission to access UI. + securityManager.checkUIViewPermissions("user3") should be (true) + securityManager.checkUIViewPermissions("user4") should be (true) + securityManager.checkUIViewPermissions("user5") should be (true) + securityManager.checkUIViewPermissions("user6") should be (false) + } + + // Test only history ui admin acls are configured. + val conf2 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.history.ui.admin.acls", "user1,user2") + .set("spark.history.ui.admin.acls.groups", "group1") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + createAndCheck(conf2) { securityManager => + // Test whether user has permission to access UI. + securityManager.checkUIViewPermissions("user1") should be (true) + securityManager.checkUIViewPermissions("user2") should be (true) + // Check the unknown "user" should return false + securityManager.checkUIViewPermissions("user") should be (false) + + // Test whether user with admin group has permission to access UI. + securityManager.checkUIViewPermissions("user3") should be (true) + securityManager.checkUIViewPermissions("user4") should be (true) + // Check the "user5" without mapping relation should return false + securityManager.checkUIViewPermissions("user5") should be (false) + } + + // Test neither history ui admin acls nor application acls are configured. + val conf3 = createTestConf() + .set("spark.history.ui.acls.enable", "true") + .set("spark.user.groups.mapping", classOf[TestGroupsMappingProvider].getName) + createAndCheck(conf3) { securityManager => + // Test whether user has permission to access UI. 
+ securityManager.checkUIViewPermissions("user1") should be (false) + securityManager.checkUIViewPermissions("user2") should be (false) + securityManager.checkUIViewPermissions("user") should be (false) + + // Test whether user with admin group has permission to access UI. + // Check should be failed since we don't have acl group settings. + securityManager.checkUIViewPermissions("user3") should be (false) + securityManager.checkUIViewPermissions("user4") should be (false) + securityManager.checkUIViewPermissions("user5") should be (false) + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -449,8 +604,14 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) val bstream = new BufferedOutputStream(cstream) if (isNewFormat) { - EventLoggingListener.initEventLog(new FileOutputStream(file)) + val newFormatStream = new FileOutputStream(file) + Utils.tryWithSafeFinally { + EventLoggingListener.initEventLog(newFormatStream, false, null) + } { + newFormatStream.close() + } } + val writer = new OutputStreamWriter(bstream, StandardCharsets.UTF_8) Utils.tryWithSafeFinally { events.foreach(e => writer.write(compact(render(JsonProtocol.sparkEventToJson(e))) + "\n")) @@ -480,3 +641,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + +class TestGroupsMappingProvider extends GroupMappingServiceProvider { + private val mappings = Map( + "user3" -> "group1", + "user4" -> "group1", + "user5" -> "group") + + override def getGroups(username: String): Set[String] = { + mappings.get(username).map(Set(_)).getOrElse(Set.empty) + } +} + diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a595bc174a310..95acb9a54440f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -20,7 +20,8 @@ import java.io.{File, FileInputStream, FileWriter, InputStream, IOException} import java.net.{HttpURLConnection, URL} import java.nio.charset.StandardCharsets import java.util.zip.ZipInputStream -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import javax.servlet._ +import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse} import scala.concurrent.duration._ import scala.language.postfixOps @@ -29,6 +30,8 @@ import com.codahale.metrics.Counter import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.eclipse.jetty.proxy.ProxyServlet +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods._ @@ -66,14 +69,15 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private var server: HistoryServer = null private var port: Int = -1 - def init(): Unit = { + def init(extraConf: (String, String)*): Unit = { val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() - val securityManager = new 
SecurityManager(conf) + val securityManager = HistoryServer.createSecurityManager(conf) server = new HistoryServer(conf, provider, securityManager, 18080) server.initialize() @@ -100,6 +104,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "minDate app list json" -> "applications?minDate=2015-02-10", "maxDate app list json" -> "applications?maxDate=2015-02-10", "maxDate2 app list json" -> "applications?maxDate=2015-02-03T16:42:40.000GMT", + "minEndDate app list json" -> "applications?minEndDate=2015-05-06T13:03:00.950GMT", + "maxEndDate app list json" -> "applications?maxEndDate=2015-05-06T13:03:00.950GMT", + "minEndDate and maxEndDate app list json" -> + "applications?minEndDate=2015-03-16&maxEndDate=2015-05-06T13:03:00.950GMT", + "minDate and maxEndDate app list json" -> + "applications?minDate=2015-03-16&maxEndDate=2015-05-06T13:03:00.950GMT", "limit app list json" -> "applications?limit=3", "one app json" -> "applications/local-1422981780767", "one app multi-attempt json" -> "applications/local-1426533911241", @@ -141,7 +151,10 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers "stage task list from multi-attempt app json(2)" -> "applications/local-1426533911241/2/stages/0/0/taskList", - "rdd list storage json" -> "applications/local-1422981780767/storage/rdd" + "rdd list storage json" -> "applications/local-1422981780767/storage/rdd", + "executor node blacklisting" -> "applications/app-20161116163331-0000/executors", + "executor node blacklisting unblacklisting" -> "applications/app-20161115172038-0000/executors", + "executor memory usage" -> "applications/app-20161116163331-0000/executors" // Todo: enable this test when logging the event of onBlockUpdated. See: SPARK-13845 // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0" ) @@ -258,8 +271,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getContentAndCode("foobar")._1 should be (HttpServletResponse.SC_NOT_FOUND) } - test("relative links are prefixed with uiRoot (spark.ui.proxyBase)") { - val proxyBaseBeforeTest = System.getProperty("spark.ui.proxyBase") + test("static relative links are prefixed with uiRoot (spark.ui.proxyBase)") { val uiRoot = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("/testwebproxybase") val page = new HistoryPage(server) val request = mock[HttpServletRequest] @@ -267,7 +279,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // when System.setProperty("spark.ui.proxyBase", uiRoot) val response = page.render(request) - System.setProperty("spark.ui.proxyBase", Option(proxyBaseBeforeTest).getOrElse("")) // then val urls = response \\ "@href" map (_.toString) @@ -275,6 +286,91 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers all (siteRelativeLinks) should startWith (uiRoot) } + test("ajax rendered relative links are prefixed with uiRoot (spark.ui.proxyBase)") { + val uiRoot = "/testwebproxybase" + System.setProperty("spark.ui.proxyBase", uiRoot) + + server.stop() + + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", logDir) + .set("spark.history.fs.update.interval", "0") + .set("spark.testing", "true") + + provider = new FsHistoryProvider(conf) + provider.checkForLogs() + val securityManager = HistoryServer.createSecurityManager(conf) + + server = new HistoryServer(conf, provider, securityManager, 18080) + server.initialize() + server.bind() + + val port = server.boundPort + + val
servlet = new ProxyServlet { + override def rewriteTarget(request: HttpServletRequest): String = { + // servlet acts like a proxy that redirects calls made on + // spark.ui.proxyBase context path to the normal servlet handlers operating off "/" + val sb = request.getRequestURL() + + if (request.getQueryString() != null) { + sb.append(s"?${request.getQueryString()}") + } + + val proxyidx = sb.indexOf(uiRoot) + sb.delete(proxyidx, proxyidx + uiRoot.length).toString + } + } + + val contextHandler = new ServletContextHandler + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(uiRoot) + contextHandler.addServlet(holder, "/") + server.attachHandler(contextHandler) + + implicit val webDriver: WebDriver = new HtmlUnitDriver(true) { + getWebClient.getOptions.setThrowExceptionOnScriptError(false) + } + + try { + val url = s"http://localhost:$port" + + go to s"$url$uiRoot" + + // expect the ajax call to finish in 5 seconds + implicitlyWait(org.scalatest.time.Span(5, org.scalatest.time.Seconds)) + + // once this findAll call returns, we know the ajax load of the table completed + findAll(ClassNameQuery("odd")) + + val links = findAll(TagNameQuery("a")) + .map(_.attribute("href")) + .filter(_.isDefined) + .map(_.get) + .filter(_.startsWith(url)).toList + + // there are at least some URL links that were generated via javascript, + // and they all contain the spark.ui.proxyBase (uiRoot) + links.length should be > 4 + all(links) should startWith(url + uiRoot) + } finally { + contextHandler.stop() + quit() + } + + } + + /** + * Verify that the security manager needed for the history server can be instantiated + * when `spark.authenticate` is `true`, rather than raise an `IllegalArgumentException`. + */ + test("security manager starts with spark.authenticate set") { + val conf = new SparkConf() + .set("spark.testing", "true") + .set(SecurityManager.SPARK_AUTH_CONF, "true") + HistoryServer.createSecurityManager(conf) + } + test("incomplete apps get refreshed") { implicit val webDriver: WebDriver = new HtmlUnitDriver @@ -294,7 +390,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set("spark.history.cache.window", "250ms") .remove("spark.testing") val provider = new FsHistoryProvider(myConf) - val securityManager = new SecurityManager(myConf) + val securityManager = HistoryServer.createSecurityManager(myConf) sc = new SparkContext("local", "test", myConf) val logDirUri = logDir.toURI @@ -469,8 +565,43 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); + logDir.deleteOnExit() + } + + test("ui and api authorization checks") { + val appId = "local-1430917381535" + val owner = "irashid" + val admin = "root" + val other = "alice" + + stop() + init( + "spark.ui.filters" -> classOf[FakeAuthFilter].getName(), + "spark.history.ui.acls.enable" -> "true", + "spark.history.ui.admin.acls" -> admin) + + val tests = Seq( + (owner, HttpServletResponse.SC_OK), + (admin, HttpServletResponse.SC_OK), + (other, HttpServletResponse.SC_FORBIDDEN), + // When the remote user is null, the code behaves as if auth were disabled.
+ (null, HttpServletResponse.SC_OK)) + val port = server.boundPort + val testUrls = Seq( + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") + + tests.foreach { case (user, expectedCode) => + testUrls.foreach { url => + val headers = if (user != null) Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) else Nil + val sc = TestUtils.httpResponseCode(new URL(url), headers = headers) + assert(sc === expectedCode, s"Unexpected status code $sc for $url (user = $user)") + } + } } def getContentAndCode(path: String, port: Int = port): (Int, Option[String], Option[String]) = { @@ -555,3 +686,26 @@ object HistoryServerSuite { } } } + +/** + * A filter used for auth tests; sets the request's user to the value of the "HTTP_USER" header. + */ +class FakeAuthFilter extends Filter { + + override def destroy(): Unit = { } + + override def init(config: FilterConfig): Unit = { } + + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val wrapped = new HttpServletRequestWrapper(hreq) { + override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER) + } + chain.doFilter(wrapped, res) + } + +} + +object FakeAuthFilter { + val FAKE_HTTP_USER = "HTTP_USER" +} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 831a7bcb12743..2127da48ece49 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -432,7 +432,7 @@ class MasterSuite extends SparkFunSuite val master = makeMaster() master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) eventually(timeout(10.seconds)) { - val masterState = master.self.askWithRetry[MasterStateResponse](RequestMasterState) + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") } @@ -447,7 +447,7 @@ class MasterSuite extends SparkFunSuite } }) - master.self.ask( + master.self.send( RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080")) val executors = (0 until 3).map { i => new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 683eeeeb6d661..efcad140350b9 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -17,46 +17,44 @@ package org.apache.spark.executor +import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer -import java.util.concurrent.CountDownLatch +import java.util.Properties +import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.collection.mutable.HashMap +import scala.collection.mutable.Map +import scala.concurrent.duration._ -import org.mockito.Matchers._ -import org.mockito.Mockito.{mock, when} +import org.mockito.ArgumentCaptor +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.{inOrder, 
verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually +import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, Task} +import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread -class ExecutorSuite extends SparkFunSuite { +class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { test("SPARK-15963: Catch `TaskKilledException` correctly in Executor.TaskRunner") { // mock some objects to make Executor.launchTask() happy val conf = new SparkConf val serializer = new JavaSerializer(conf) - val mockEnv = mock(classOf[SparkEnv]) - val mockRpcEnv = mock(classOf[RpcEnv]) - val mockMetricsSystem = mock(classOf[MetricsSystem]) - val mockMemoryManager = mock(classOf[MemoryManager]) - when(mockEnv.conf).thenReturn(conf) - when(mockEnv.serializer).thenReturn(serializer) - when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) - when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) - when(mockEnv.memoryManager).thenReturn(mockMemoryManager) - when(mockEnv.closureSerializer).thenReturn(serializer) - val serializedTask = - Task.serializeWithDependencies( - new FakeTask(0, 0), - HashMap[String, Long](), - HashMap[String, Long](), - serializer.newInstance()) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) + val taskDescription = createFakeTaskDescription(serializedTask) // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -78,7 +76,7 @@ class ExecutorSuite extends SparkFunSuite { val executorSuiteHelper = new ExecutorSuiteHelper - val mockExecutorBackend = mock(classOf[ExecutorBackend]) + val mockExecutorBackend = mock[ExecutorBackend] when(mockExecutorBackend.statusUpdate(any(), any(), any())) .thenAnswer(new Answer[Unit] { var firstTime = true @@ -94,8 +92,8 @@ class ExecutorSuite extends SparkFunSuite { val taskState = invocationOnMock.getArguments()(1).asInstanceOf[TaskState] executorSuiteHelper.taskState = taskState val taskEndReason = invocationOnMock.getArguments()(2).asInstanceOf[ByteBuffer] - executorSuiteHelper.testFailedReason - = serializer.newInstance().deserialize(taskEndReason) + executorSuiteHelper.testFailedReason = + serializer.newInstance().deserialize(taskEndReason) // let the main test thread check `taskState` and `testFailedReason` executorSuiteHelper.latch3.countDown() } @@ -104,19 +102,23 @@ class ExecutorSuite extends SparkFunSuite { var executor: Executor = null try { - executor = new Executor("id", "localhost", mockEnv, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread - executor.launchTask(mockExecutorBackend, 0, 0, "", serializedTask) + executor.launchTask(mockExecutorBackend, taskDescription) - executorSuiteHelper.latch1.await() + if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { + 
fail("executor did not send first status update in time") + } // we know the task will be started, but not yet deserialized, because of the latches we // use in mockExecutorBackend. - executor.killAllTasks(true) + executor.killAllTasks(true, "test") executorSuiteHelper.latch2.countDown() - executorSuiteHelper.latch3.await() + if (!executorSuiteHelper.latch3.await(5, TimeUnit.SECONDS)) { + fail("executor did not send second status update in time") + } // `testFailedReason` should be `TaskKilled`; `taskState` should be `KILLED` - assert(executorSuiteHelper.testFailedReason === TaskKilled) + assert(executorSuiteHelper.testFailedReason === TaskKilled("test")) assert(executorSuiteHelper.taskState === TaskState.KILLED) } finally { @@ -125,6 +127,216 @@ class ExecutorSuite extends SparkFunSuite { } } } + + test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it + // may be a false positive. And we should call the uncaught exception handler. + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat + // the fetch failure as a false positive, and just do normal OOM handling. 
+ val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val (failReason, uncaughtExceptionHandler) = + runTaskGetFailReasonAndExceptionHandler(taskDescription) + // make sure the task failure just looks like a OOM, not a fetch failure + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) + } + + test("Gracefully handle error in task deserialization") { + val conf = new SparkConf + val serializer = new JavaSerializer(conf) + val env = createMockEnv(conf, serializer) + val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) + val taskDescription = createFakeTaskDescription(serializedTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + failReason match { + case ef: ExceptionFailure => + assert(ef.exception.isDefined) + assert(ef.exception.get.getMessage() === NonDeserializableTask.errorMsg) + case _ => + fail(s"unexpected failure type: $failReason") + } + } + + private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = { + val mockEnv = mock[SparkEnv] + val mockRpcEnv = mock[RpcEnv] + val mockMetricsSystem = mock[MetricsSystem] + val mockMemoryManager = mock[MemoryManager] + when(mockEnv.conf).thenReturn(conf) + when(mockEnv.serializer).thenReturn(serializer) + when(mockEnv.rpcEnv).thenReturn(mockRpcEnv) + when(mockEnv.metricsSystem).thenReturn(mockMetricsSystem) + when(mockEnv.memoryManager).thenReturn(mockMemoryManager) + when(mockEnv.closureSerializer).thenReturn(serializer) + SparkEnv.set(mockEnv) + mockEnv + } + + private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + new TaskDescription( + taskId = 0, + attemptNumber = 0, + executorId = "", + name = "", + index = 0, + addedFiles = Map[String, Long](), + addedJars = Map[String, Long](), + properties = new Properties, + serializedTask) + } + + private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + } + + private def runTaskGetFailReasonAndExceptionHandler( + taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { + val mockBackend = mock[ExecutorBackend] + val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] + var executor: Executor = null + try { + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) + // the task will be launched in a dedicated worker thread + executor.launchTask(mockBackend, taskDescription) + eventually(timeout(5.seconds), interval(10.milliseconds)) { + assert(executor.numRunningTasks === 0) + } + } finally { + if 
(executor != null) { + executor.stop() + } + } + val orderedMock = inOrder(mockBackend) + val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer]) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.RUNNING), statusCaptor.capture()) + orderedMock.verify(mockBackend) + .statusUpdate(meq(0L), meq(TaskState.FAILED), statusCaptor.capture()) + // first statusUpdate for RUNNING has empty data + assert(statusCaptor.getAllValues().get(0).remaining() === 0) + // second update is more interesting + val failureData = statusCaptor.getAllValues.get(1) + val failReason = + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + (failReason, mockUncaughtExceptionHandler) + } +} + +class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + new Iterator[Int] { + override def hasNext: Boolean = true + override def next(): Int = { + throw new FetchFailedException( + bmAddress = BlockManagerId("1", "hostA", 1234), + shuffleId = 0, + mapId = 0, + reduceId = 0, + message = "fake fetch failure" + ) + } + } + } + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } +} + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +class FetchFailureHidingRDD( + sc: SparkContext, + val input: FetchFailureThrowingRDD, + throwOOM: Boolean) extends RDD[Int](input) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + val inItr = input.compute(split, context) + try { + Iterator(inItr.size) + } catch { + case t: Throwable => + if (throwOOM) { + throw new OutOfMemoryError("OOM while handling another exception") + } else { + throw new RuntimeException("User Exception that hides the original exception", t) + } + } + } + + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } } // Helps to test("SPARK-15963") @@ -137,3 +349,14 @@ private class ExecutorSuiteHelper { @volatile var taskState: TaskState = _ @volatile var testFailedReason: TaskFailedReason = _ } + +private class NonDeserializableTask extends FakeTask(0, 0) with Externalizable { + def writeExternal(out: ObjectOutput): Unit = {} + def readExternal(in: ObjectInput): Unit = { + throw new RuntimeException(NonDeserializableTask.errorMsg) + } +} + +private object NonDeserializableTask { + val errorMsg = "failure in deserialization" +} diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 91a96bdda6833..b72cd8be24206 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -17,11 +17,9 @@ package org.apache.spark.internal.config +import java.util.Locale import java.util.concurrent.TimeUnit -import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap - import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.network.util.ByteUnit import org.apache.spark.util.SparkConfWithEnv @@ -98,6 +96,21 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(bytes) === 1L) } + test("conf entry: regex") { + val conf = new SparkConf() + val rConf = ConfigBuilder(testKey("regex")).regexConf.createWithDefault(".*".r) + + conf.set(rConf, "[0-9a-f]{8}".r) + assert(conf.get(rConf).toString === "[0-9a-f]{8}") + + 
conf.set(rConf.key, "[0-9a-f]{4}") + assert(conf.get(rConf).toString === "[0-9a-f]{4}") + + conf.set(rConf.key, "[.") + val e = intercept[IllegalArgumentException](conf.get(rConf)) + assert(e.getMessage.contains("regex should be a regex, but was")) + } + test("conf entry: string seq") { val conf = new SparkConf() val seq = ConfigBuilder(testKey("seq")).stringConf.toSequence.createWithDefault(Seq()) @@ -120,7 +133,7 @@ class ConfigEntrySuite extends SparkFunSuite { val conf = new SparkConf() val transformationConf = ConfigBuilder(testKey("transformation")) .stringConf - .transform(_.toLowerCase()) + .transform(_.toLowerCase(Locale.ROOT)) .createWithDefault("FOO") assert(conf.get(transformationConf) === "foo") @@ -128,6 +141,28 @@ class ConfigEntrySuite extends SparkFunSuite { assert(conf.get(transformationConf) === "bar") } + test("conf entry: checkValue()") { + def createEntry(default: Int): ConfigEntry[Int] = + ConfigBuilder(testKey("checkValue")) + .intConf + .checkValue(value => value >= 0, "value must be non-negative") + .createWithDefault(default) + + val conf = new SparkConf() + + val entry = createEntry(10) + conf.set(entry, -1) + val e1 = intercept[IllegalArgumentException] { + conf.get(entry) + } + assert(e1.getMessage == "value must be non-negative") + + val e2 = intercept[IllegalArgumentException] { + createEntry(-1) + } + assert(e2.getMessage == "value must be non-negative") + } + test("conf entry: valid values check") { val conf = new SparkConf() val enum = ConfigBuilder(testKey("enum")) @@ -218,4 +253,12 @@ class ConfigEntrySuite extends SparkFunSuite { testEntryRef(nullConf, ref(nullConf)) } + test("conf entry : default function") { + var data = 0 + val conf = new SparkConf() + val iConf = ConfigBuilder(testKey("intval")).intConf.createWithDefaultFunction(() => data) + assert(conf.get(iConf) === 0) + data = 2 + assert(conf.get(iConf) === 2) + } } diff --git a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala index cac15a1dc4414..c88cc13654ce5 100644 --- a/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/launcher/LauncherBackendSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark._ +import org.apache.spark.util.Utils class LauncherBackendSuite extends SparkFunSuite with Matchers { @@ -35,6 +36,8 @@ class LauncherBackendSuite extends SparkFunSuite with Matchers { tests.foreach { case (name, master) => test(s"$name: launcher handle") { + // The tests here are failed due to the cmd length limitation up to 8K on Windows. 
+ assume(!Utils.isWindows) testWithMaster(master) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index f8054f5fd7701..5d522189a0c29 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutput import org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.util.Utils @@ -61,7 +60,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext pw.close() // Path to tmpFile - tmpFilePath = "file://" + tmpFile.getAbsolutePath + tmpFilePath = tmpFile.toURI.toString } after { @@ -181,15 +180,12 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext sc.textFile(tmpFilePath, 4) .map(key => (key, 1)) .reduceByKey(_ + _) - .saveAsTextFile("file://" + tmpFile.getAbsolutePath) + .saveAsTextFile(tmpFile.toURI.toString) sc.listenerBus.waitUntilEmpty(500) assert(inputRead == numRecords) - // Only supported on newer Hadoop - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { - assert(outputWritten == numBuckets) - } + assert(outputWritten == numBuckets) assert(shuffleRead == shuffleWritten) } @@ -197,7 +193,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext val numPartitions = 2 val cartVector = 0 to 9 val cartFile = new File(tmpDir, getClass.getSimpleName + "_cart.txt") - val cartFilePath = "file://" + cartFile.getAbsolutePath + val cartFilePath = cartFile.toURI.toString // write files to disk so we can read them later. 
sc.parallelize(cartVector).saveAsTextFile(cartFilePath) @@ -262,57 +258,49 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext } test("output metrics on records written") { - // Only supported on newer Hadoop - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { - val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = file.toURI.toURL.toString - val records = runAndReturnRecordsWritten { - sc.parallelize(1 to numRecords).saveAsTextFile(filePath) - } - assert(records == numRecords) + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).saveAsTextFile(filePath) } + assert(records == numRecords) } test("output metrics on records written - new Hadoop API") { - // Only supported on newer Hadoop - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { - val file = new File(tmpDir, getClass.getSimpleName) - val filePath = "file://" + file.getAbsolutePath - - val records = runAndReturnRecordsWritten { - sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) - .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](filePath) - } - assert(records == numRecords) + val file = new File(tmpDir, getClass.getSimpleName) + val filePath = file.toURI.toURL.toString + + val records = runAndReturnRecordsWritten { + sc.parallelize(1 to numRecords).map(key => (key.toString, key.toString)) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](filePath) } + assert(records == numRecords) } test("output metrics when writing text file") { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { - val taskBytesWritten = new ArrayBuffer[Long]() - sc.addSparkListener(new SparkListener() { - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesWritten += taskEnd.taskMetrics.outputMetrics.bytesWritten - } - }) - - val rdd = sc.parallelize(Array("a", "b", "c", "d"), 2) - - try { - rdd.saveAsTextFile(outPath.toString) - sc.listenerBus.waitUntilEmpty(500) - assert(taskBytesWritten.length == 2) - val outFiles = fs.listStatus(outPath).filter(_.getPath.getName != "_SUCCESS") - taskBytesWritten.zip(outFiles).foreach { case (bytes, fileStatus) => - assert(bytes >= fileStatus.getLen) - } - } finally { - fs.delete(outPath, true) + val taskBytesWritten = new ArrayBuffer[Long]() + sc.addSparkListener(new SparkListener() { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + taskBytesWritten += taskEnd.taskMetrics.outputMetrics.bytesWritten + } + }) + + val rdd = sc.parallelize(Array("a", "b", "c", "d"), 2) + + try { + rdd.saveAsTextFile(outPath.toString) + sc.listenerBus.waitUntilEmpty(500) + assert(taskBytesWritten.length == 2) + val outFiles = fs.listStatus(outPath).filter(_.getPath.getName != "_SUCCESS") + taskBytesWritten.zip(outFiles).foreach { case (bytes, fileStatus) => + assert(bytes >= fileStatus.getLen) } + } finally { + fs.delete(outPath, true) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 022fe91edade9..fe8955840d72f 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ 
b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -94,6 +94,20 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi } } + test("security with aes encryption") { + val conf = new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.app.id", "app-id") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false") + testConnection(conf, conf) match { + case Success(_) => // expected + case Failure(t) => fail(t) + } + } + + /** * Creates two servers with different configurations and sees if they can talk. * Returns Success() if they can transfer a block, and Failure() if the block transfer was failed diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 121447a96529b..271ab8b148831 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -23,7 +23,6 @@ import org.mockito.Mockito.mock import org.scalatest._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} -import org.apache.spark.internal.config._ import org.apache.spark.network.BlockDataManager class NettyBlockTransferServiceSuite diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 58664e77d24a5..b29a53cffeb51 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -199,10 +199,9 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim val f = sc.parallelize(1 to 100, 4) .mapPartitions(itr => { Thread.sleep(20); itr }) .countAsync() - val e = intercept[SparkException] { + intercept[TimeoutException] { ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) } - assert(e.getCause.isInstanceOf[TimeoutException]) } private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index b0d69de6e2ef4..02df157be377c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -516,10 +516,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored") /* - Check that configurable formats get configured: - ConfigTestFormat throws an exception if we try to write - to it when setConf hasn't been called first. - Assertion is in ConfigTestFormat.getRecordWriter. + * Check that configurable formats get configured: + * ConfigTestFormat throws an exception if we try to write + * to it when setConf hasn't been called first. + * Assertion is in ConfigTestFormat.getRecordWriter. 
*/ pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored") } @@ -544,7 +544,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val e = intercept[SparkException] { pairs.saveAsNewAPIHadoopFile[NewFakeFormatWithCallback]("ignored") } - assert(e.getMessage contains "failed to write") + assert(e.getCause.getMessage contains "failed to write") assert(FakeWriterWithCallback.calledBy === "write,callback,close") assert(FakeWriterWithCallback.exception != null, "exception should be captured") @@ -725,8 +725,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } /* - These classes are fakes for testing - "saveNewAPIHadoopFile should call setConf if format is configurable". + These classes are fakes for testing saveAsHadoopFile/saveNewAPIHadoopFile. Unfortunately, they have to be top level classes, and not defined in the test method, because otherwise Scala won't generate no-args constructors and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 7293aa9a2584f..1a0eb250e7cdc 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -21,8 +21,6 @@ import java.io.File import scala.collection.Map import scala.io.Codec -import scala.sys.process._ -import scala.util.Try import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{LongWritable, Text} @@ -32,109 +30,104 @@ import org.apache.spark._ import org.apache.spark.util.Utils class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { + val envCommand = if (Utils.isWindows) { + "cmd.exe /C set" + } else { + "printenv" + } test("basic pipe") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assume(TestUtils.testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat")) + val piped = nums.pipe(Seq("cat")) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - } else { - assert(true) - } + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") } test("basic pipe with tokenization") { - if (testCommandAvailable("wc")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + assume(TestUtils.testCommandAvailable("wc")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work good - for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { - val c = piped.collect() - assert(c.size === 2) - assert(c(0).trim === "2") - assert(c(1).trim === "2") - } - } else { - assert(true) + // verify that both RDD.pipe(command: String) and RDD.pipe(command: String, env) work good + for (piped <- Seq(nums.pipe("wc -l"), nums.pipe("wc -l", Map[String, String]()))) { + val c = piped.collect() + assert(c.size === 2) + assert(c(0).trim === "2") + assert(c(1).trim === "2") } } test("failure in iterating over pipe input") { - if (testCommandAvailable("cat")) { - val nums = - sc.makeRDD(Array(1, 2, 3, 4), 2) - .mapPartitionsWithIndex((index, iterator) => { - new Iterator[Int] { - def hasNext = true - def next() = { - throw new SparkException("Exception to simulate bad scenario") - } - } - }) + 
assume(TestUtils.testCommandAvailable("cat")) + val nums = + sc.makeRDD(Array(1, 2, 3, 4), 2) + .mapPartitionsWithIndex((index, iterator) => { + new Iterator[Int] { + def hasNext = true + def next() = { + throw new SparkException("Exception to simulate bad scenario") + } + } + }) - val piped = nums.pipe(Seq("cat")) + val piped = nums.pipe(Seq("cat")) - intercept[SparkException] { - piped.collect() - } + intercept[SparkException] { + piped.collect() } } test("advanced pipe") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val bl = sc.broadcast(List("0")) - - val piped = nums.pipe(Seq("cat"), + assume(TestUtils.testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val bl = sc.broadcast(List("0")) + + val piped = nums.pipe(Seq("cat"), + Map[String, String](), + (f: String => Unit) => { + bl.value.foreach(f); f("\u0001") + }, + (i: Int, f: String => Unit) => f(i + "_")) + + val c = piped.collect() + + assert(c.size === 8) + assert(c(0) === "0") + assert(c(1) === "\u0001") + assert(c(2) === "1_") + assert(c(3) === "2_") + assert(c(4) === "0") + assert(c(5) === "\u0001") + assert(c(6) === "3_") + assert(c(7) === "4_") + + val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) + val d = nums1.groupBy(str => str.split("\t")(0)). + pipe(Seq("cat"), Map[String, String](), (f: String => Unit) => { bl.value.foreach(f); f("\u0001") }, - (i: Int, f: String => Unit) => f(i + "_")) - - val c = piped.collect() - - assert(c.size === 8) - assert(c(0) === "0") - assert(c(1) === "\u0001") - assert(c(2) === "1_") - assert(c(3) === "2_") - assert(c(4) === "0") - assert(c(5) === "\u0001") - assert(c(6) === "3_") - assert(c(7) === "4_") - - val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2) - val d = nums1.groupBy(str => str.split("\t")(0)). - pipe(Seq("cat"), - Map[String, String](), - (f: String => Unit) => { - bl.value.foreach(f); f("\u0001") - }, - (i: Tuple2[String, Iterable[String]], f: String => Unit) => { - for (e <- i._2) { - f(e + "_") - } - }).collect() - assert(d.size === 8) - assert(d(0) === "0") - assert(d(1) === "\u0001") - assert(d(2) === "b\t2_") - assert(d(3) === "b\t4_") - assert(d(4) === "0") - assert(d(5) === "\u0001") - assert(d(6) === "a\t1_") - assert(d(7) === "a\t3_") - } else { - assert(true) - } + (i: Tuple2[String, Iterable[String]], f: String => Unit) => { + for (e <- i._2) { + f(e + "_") + } + }).collect() + assert(d.size === 8) + assert(d(0) === "0") + assert(d(1) === "\u0001") + assert(d(2) === "b\t2_") + assert(d(3) === "b\t4_") + assert(d(4) === "0") + assert(d(5) === "\u0001") + assert(d(6) === "a\t1_") + assert(d(7) === "a\t3_") } test("pipe with empty partition") { @@ -142,67 +135,67 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { val piped = data.pipe("wc -c") assert(piped.count == 8) val charCounts = piped.map(_.trim.toInt).collect().toSet - assert(Set(0, 4, 5) == charCounts) + val expected = if (Utils.isWindows) { + // Note that newline character on Windows is \r\n which are two. 
+ Set(0, 5, 6) + } else { + Set(0, 4, 5) + } + assert(expected == charCounts) } test("pipe with env variable") { - if (testCommandAvailable("printenv")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA")) - val c = piped.collect() - assert(c.size === 2) - assert(c(0) === "LALALA") - assert(c(1) === "LALALA") - } else { - assert(true) - } + assume(TestUtils.testCommandAvailable(envCommand)) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(s"$envCommand MY_TEST_ENV", Map("MY_TEST_ENV" -> "LALALA")) + val c = piped.collect() + assert(c.length === 2) + // On Windows, `cmd.exe /C set` is used which prints out it as `varname=value` format + // whereas `printenv` usually prints out `value`. So, `varname=` is stripped here for both. + assert(c(0).stripPrefix("MY_TEST_ENV=") === "LALALA") + assert(c(1).stripPrefix("MY_TEST_ENV=") === "LALALA") } test("pipe with process which cannot be launched due to bad command") { - if (!testCommandAvailable("some_nonexistent_command")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val command = Seq("some_nonexistent_command") - val piped = nums.pipe(command) - val exception = intercept[SparkException] { - piped.collect() - } - assert(exception.getMessage.contains(command.mkString(" "))) + assume(!TestUtils.testCommandAvailable("some_nonexistent_command")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val command = Seq("some_nonexistent_command") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { + piped.collect() } + assert(exception.getMessage.contains(command.mkString(" "))) } test("pipe with process which is launched but fails with non-zero exit status") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val command = Seq("cat", "nonexistent_file") - val piped = nums.pipe(command) - val exception = intercept[SparkException] { - piped.collect() - } - assert(exception.getMessage.contains(command.mkString(" "))) + assume(TestUtils.testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val command = Seq("cat", "nonexistent_file") + val piped = nums.pipe(command) + val exception = intercept[SparkException] { + piped.collect() } + assert(exception.getMessage.contains(command.mkString(" "))) } test("basic pipe with separate working directory") { - if (testCommandAvailable("cat")) { - val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) - val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) - val c = piped.collect() - assert(c.size === 4) - assert(c(0) === "1") - assert(c(1) === "2") - assert(c(2) === "3") - assert(c(3) === "4") - val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) - val collectPwd = pipedPwd.collect() - assert(collectPwd(0).contains("tasks/")) - val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true, bufferSize = 16384).collect() - // make sure symlinks were created - assert(pipedLs.length > 0) - // clean up top level tasks directory - Utils.deleteRecursively(new File("tasks")) - } else { - assert(true) - } + assume(TestUtils.testCommandAvailable("cat")) + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) + val collectPwd = pipedPwd.collect() + 
assert(collectPwd(0).contains("tasks/")) + val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true, bufferSize = 16384).collect() + // make sure symlinks were created + assert(pipedLs.length > 0) + // clean up top level tasks directory + Utils.deleteRecursively(new File("tasks")) } test("test pipe exports map_input_file") { @@ -213,42 +206,36 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { testExportInputFile("mapreduce_map_input_file") } - def testCommandAvailable(command: String): Boolean = { - val attempt = Try(Process(command).run(ProcessLogger(_ => ())).exitValue()) - attempt.isSuccess && attempt.get == 0 - } - def testExportInputFile(varName: String) { - if (testCommandAvailable("printenv")) { - val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], - classOf[Text], 2) { - override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) + assume(TestUtils.testCommandAvailable(envCommand)) + val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable], + classOf[Text], 2) { + override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition()) - override val getDependencies = List[Dependency[_]]() + override val getDependencies = List[Dependency[_]]() - override def compute(theSplit: Partition, context: TaskContext) = { - new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), - new Text("b")))) - } + override def compute(theSplit: Partition, context: TaskContext) = { + new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1), + new Text("b")))) } - val hadoopPart1 = generateFakeHadoopPartition() - val pipedRdd = - new PipedRDD( - nums, - PipedRDD.tokenize("printenv " + varName), - Map(), - null, - null, - false, - 4092, - Codec.defaultCharsetCodec.name) - val tContext = TaskContext.empty() - val rddIter = pipedRdd.compute(hadoopPart1, tContext) - val arr = rddIter.toArray - assert(arr(0) == "/some/path") - } else { - // printenv isn't available so just pass the test } + val hadoopPart1 = generateFakeHadoopPartition() + val pipedRdd = + new PipedRDD( + nums, + PipedRDD.tokenize(s"$envCommand $varName"), + Map(), + null, + null, + false, + 4092, + Codec.defaultCharsetCodec.name) + val tContext = TaskContext.empty() + val rddIter = pipedRdd.compute(hadoopPart1, tContext) + val arr = rddIter.toArray + // On Windows, `cmd.exe /C set` is used which prints out it as `varname=value` format + // whereas `printenv` usually prints out `value`. So, `varname=` is stripped here for both. 
+ assert(arr(0).stripPrefix(s"$varName=") === "/some/path") } def generateFakeHadoopPartition(): HadoopPartition = { diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a2..7f20206202cb9 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index acdf21df9a161..31d9dd3de8acc 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -118,8 +118,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") - val reply = newRpcEndpointRef.askWithRetry[String]("Echo") + val newRpcEndpointRef = rpcEndpointRef.askSync[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askSync[String]("Echo") assert("Echo" === reply) } @@ -132,7 +132,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { context.reply(msg) } }) - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } @@ -150,7 +150,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } finally { anotherEnv.shutdown() @@ -177,14 +177,13 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // Use anotherEnv to find out the RpcEndpointRef val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { - // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause - val e = intercept[SparkException] { - rpcEndpointRef.askWithRetry[String]("hello", new RpcTimeout(1 millis, shortProp)) + val e = intercept[RpcTimeoutException] { + rpcEndpointRef.askSync[String]("hello", new RpcTimeout(1 millis, shortProp)) } // The SparkException cause should be a RpcTimeoutException with message indicating the // controlling timeout property - assert(e.getCause.isInstanceOf[RpcTimeoutException]) - assert(e.getCause.getMessage.contains(shortProp)) + assert(e.isInstanceOf[RpcTimeoutException]) + assert(e.getMessage.contains(shortProp)) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -637,11 +636,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(anotherEnv.address.port != env.address.port) } - test("send with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - 
conf.set("spark.authenticate.secret", "good") - + private def testSend(conf: SparkConf): Unit = { val localEnv = createRpcEnv(conf, "authentication-local", 0) val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) @@ -667,11 +662,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("ask with authentication") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - + private def testAsk(conf: SparkConf): Unit = { val localEnv = createRpcEnv(conf, "authentication-local", 0) val remoteEnv = createRpcEnv(conf, "authentication-remote", 0, clientMode = true) @@ -685,7 +676,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } }) val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") - val reply = rpcEndpointRef.askWithRetry[String]("hello") + val reply = rpcEndpointRef.askSync[String]("hello") assert("hello" === reply) } finally { localEnv.shutdown() @@ -695,6 +686,48 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } + test("send with authentication") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good")) + } + + test("send with SASL encryption") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.authenticate.enableSaslEncryption", "true")) + } + + test("send with AES encryption") { + testSend(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false")) + } + + test("ask with authentication") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good")) + } + + test("ask with SASL encryption") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.authenticate.enableSaslEncryption", "true")) + } + + test("ask with AES encryption") { + testAsk(new SparkConf() + .set("spark.authenticate", "true") + .set("spark.authenticate.secret", "good") + .set("spark.network.crypto.enabled", "true") + .set("spark.network.crypto.saslFallback", "false")) + } + test("construct RpcTimeout with conf property") { val conf = new SparkConf @@ -860,7 +893,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val ref = anotherEnv.setupEndpointRef(env.address, "SPARK-14699") // Make sure the connect is set up - assert(ref.askWithRetry[String]("hello") === "hello") + assert(ref.askSync[String]("hello") === "hello") anotherEnv.shutdown() anotherEnv.awaitTermination() diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 0409aa3a5dee1..2b1bce4d208f6 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.rpc.netty +import org.scalatest.mock.MockitoSugar + import org.apache.spark._ +import org.apache.spark.network.client.TransportClient import org.apache.spark.rpc._ -class NettyRpcEnvSuite extends RpcEnvSuite { +class NettyRpcEnvSuite extends RpcEnvSuite with MockitoSugar { override def createRpcEnv( conf: SparkConf, @@ -53,4 +56,32 @@ class NettyRpcEnvSuite 
extends RpcEnvSuite { } } + test("RequestMessage serialization") { + def assertRequestMessageEquals(expected: RequestMessage, actual: RequestMessage): Unit = { + assert(expected.senderAddress === actual.senderAddress) + assert(expected.receiver === actual.receiver) + assert(expected.content === actual.content) + } + + val nettyEnv = env.asInstanceOf[NettyRpcEnv] + val client = mock[TransportClient] + val senderAddress = RpcAddress("locahost", 12345) + val receiverAddress = RpcEndpointAddress("localhost", 54321, "test") + val receiver = new NettyRpcEndpointRef(nettyEnv.conf, receiverAddress, nettyEnv) + + val msg = new RequestMessage(senderAddress, receiver, "foo") + assertRequestMessageEquals( + msg, + RequestMessage(nettyEnv, client, msg.serialize(nettyEnv))) + + val msg2 = new RequestMessage(null, receiver, "foo") + assertRequestMessageEquals( + msg2, + RequestMessage(nettyEnv, client, msg2.serialize(nettyEnv))) + + val msg3 = new RequestMessage(senderAddress, receiver, null) + assertRequestMessageEquals( + msg3, + RequestMessage(nettyEnv, client, msg3.serialize(nettyEnv))) + } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 0c156fef0ae0f..a71d8726e7066 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -34,7 +34,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) val sm = mock(classOf[StreamManager]) when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any())) - .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null)) + .thenReturn(new RequestMessage(RpcAddress("localhost", 12345), null, null)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index b2e7ec5df015c..2b18ebee79a2b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -17,10 +17,391 @@ package org.apache.spark.scheduler -import org.apache.spark.{SparkConf, SparkFunSuite} +import org.mockito.invocation.InvocationOnMock +import org.mockito.Matchers.any +import org.mockito.Mockito.{never, verify, when} +import org.mockito.stubbing.Answer +import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar + +import org.apache.spark._ import org.apache.spark.internal.config +import org.apache.spark.util.ManualClock + +class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with MockitoSugar + with LocalSparkContext { + + private val clock = new ManualClock(0) + + private var blacklist: BlacklistTracker = _ + private var listenerBusMock: LiveListenerBus = _ + private var scheduler: TaskSchedulerImpl = _ + private var conf: SparkConf = _ + + override def beforeEach(): Unit = { + conf = new SparkConf().setAppName("test").setMaster("local") + .set(config.BLACKLIST_ENABLED.key, "true") + scheduler = mockTaskSchedWithConf(conf) + + clock.setTime(0) + + listenerBusMock = mock[LiveListenerBus] + blacklist = new BlacklistTracker(listenerBusMock, conf, None, clock) + } + + override def afterEach(): Unit = { + if (blacklist != null) { + blacklist = null + } + if (scheduler != null) { + scheduler.stop() + 
scheduler = null + } + super.afterEach() + } + + // All executors and hosts used in tests should be in this set, so that [[assertEquivalentToSet]] + // works. Its OK if its got extraneous entries + val allExecutorAndHostIds = { + (('A' to 'Z')++ (1 to 100).map(_.toString)) + .flatMap{ suffix => + Seq(s"host$suffix", s"host-$suffix") + } + }.toSet + + /** + * Its easier to write our tests as if we could directly look at the sets of nodes & executors in + * the blacklist. However the api doesn't expose a set, so this is a simple way to test + * something similar, since we know the universe of values that might appear in these sets. + */ + def assertEquivalentToSet(f: String => Boolean, expected: Set[String]): Unit = { + allExecutorAndHostIds.foreach { id => + val actual = f(id) + val exp = expected.contains(id) + assert(actual === exp, raw"""for string "$id" """) + } + } + + def mockTaskSchedWithConf(conf: SparkConf): TaskSchedulerImpl = { + sc = new SparkContext(conf) + val scheduler = mock[TaskSchedulerImpl] + when(scheduler.sc).thenReturn(sc) + when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker) + scheduler + } + + def createTaskSetBlacklist(stageId: Int = 0): TaskSetBlacklist = { + new TaskSetBlacklist(conf, stageId, clock) + } + + test("executors can be blacklisted with only a few failures per stage") { + // For many different stages, executor 1 fails a task, then executor 2 succeeds the task, + // and then the task set is done. Not enough failures to blacklist the executor *within* + // any particular taskset, but we still blacklist the executor overall eventually. + // Also, we intentionally have a mix of task successes and failures -- there are even some + // successes after the executor is blacklisted. The idea here is those tasks get scheduled + // before the executor is blacklisted. We might get successes after blacklisting (because the + // executor might be flaky but not totally broken). But successes should not unblacklist the + // executor. + val failuresUntilBlacklisted = conf.get(config.MAX_FAILURES_PER_EXEC) + var failuresSoFar = 0 + (0 until failuresUntilBlacklisted * 10).foreach { stageId => + val taskSetBlacklist = createTaskSetBlacklist(stageId) + if (stageId % 2 == 0) { + // fail one task in every other taskset + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + failuresSoFar += 1 + } + blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) + assert(failuresSoFar == stageId / 2 + 1) + if (failuresSoFar < failuresUntilBlacklisted) { + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } else { + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + verify(listenerBusMock).post( + SparkListenerExecutorBlacklisted(0, "1", failuresUntilBlacklisted)) + } + } + } + + // If an executor has many task failures, but the task set ends up failing, it shouldn't be + // counted against the executor. + test("executors aren't blacklisted as a result of tasks in failed task sets") { + val failuresUntilBlacklisted = conf.get(config.MAX_FAILURES_PER_EXEC) + // for many different stages, executor 1 fails a task, and then the taskSet fails. 
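+    // Note that updateBlacklistForSuccessfulTaskSet is never called for these stages, so the
+    // per-stage failures should never feed into the application-level blacklist.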
+ (0 until failuresUntilBlacklisted * 10).foreach { stage => + val taskSetBlacklist = createTaskSetBlacklist(stage) + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + } + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + Seq(true, false).foreach { succeedTaskSet => + val label = if (succeedTaskSet) "success" else "failure" + test(s"stage blacklist updates correctly on stage $label") { + // Within one taskset, an executor fails a few times, so it's blacklisted for the taskset. + // But if the taskset fails, we shouldn't blacklist the executor after the stage. + val taskSetBlacklist = createTaskSetBlacklist(0) + // We trigger enough failures for both the taskset blacklist, and the application blacklist. + val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), + conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) + (0 until numFailures).foreach { index => + taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + } + assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + if (succeedTaskSet) { + // The task set succeeded elsewhere, so we should count those failures against our executor, + // and it should be blacklisted for the entire application. + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", numFailures)) + } else { + // The task set failed, so we don't count these failures against the executor for other + // stages. + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + } + } + + test("blacklisted executors and nodes get recovered with time") { + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole + // application. + (0 until 4).foreach { partition => + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 4)) + + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole + // application. Since that's the second executor that is blacklisted on the same node, we also + // blacklist that node. + (0 until 4).foreach { partition => + taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) + assert(blacklist.nodeBlacklist() === Set("hostA")) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2)) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 4)) + + // Advance the clock and then make sure hostA and executors 1 and 2 have been removed from the + // blacklist. 
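+    // The unblacklist events posted below should be stamped with this timeout value, since the
+    // clock started at 0 in this test.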
+ val timeout = blacklist.BLACKLIST_TIMEOUT_MILLIS + 1 + clock.advance(timeout) + blacklist.applyBlacklistTimeout() + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "2")) + verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(timeout, "1")) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(timeout, "hostA")) + + // Fail one more task, but executor isn't put back into blacklist since the count of failures + // on that executor should have been reset to 0. + val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) + assert(blacklist.nodeBlacklist() === Set()) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + test("blacklist can handle lost executors") { + // The blacklist should still work if an executor is killed completely. We should still + // be able to blacklist the entire node. + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + // Lets say that executor 1 dies completely. We get some task failures, but + // the taskset then finishes successfully (elsewhere). + (0 until 4).foreach { partition => + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.handleRemovedExecutor("1") + blacklist.updateBlacklistForSuccessfulTaskSet( + stageId = 0, + stageAttemptId = 0, + taskSetBlacklist0.execToFailures) + assert(blacklist.isExecutorBlacklisted("1")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 4)) + val t1 = blacklist.BLACKLIST_TIMEOUT_MILLIS / 2 + clock.advance(t1) -class BlacklistTrackerSuite extends SparkFunSuite { + // Now another executor gets spun up on that host, but it also dies. + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + (0 until 4).foreach { partition => + taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.handleRemovedExecutor("2") + blacklist.updateBlacklistForSuccessfulTaskSet( + stageId = 1, + stageAttemptId = 0, + taskSetBlacklist1.execToFailures) + // We've now had two bad executors on the hostA, so we should blacklist the entire node. + assert(blacklist.isExecutorBlacklisted("1")) + assert(blacklist.isExecutorBlacklisted("2")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(t1, "2", 4)) + assert(blacklist.isNodeBlacklisted("hostA")) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(t1, "hostA", 2)) + + // Advance the clock so that executor 1 should no longer be explicitly blacklisted, but + // everything else should still be blacklisted. + val t2 = blacklist.BLACKLIST_TIMEOUT_MILLIS / 2 + 1 + clock.advance(t2) + blacklist.applyBlacklistTimeout() + assert(!blacklist.isExecutorBlacklisted("1")) + verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(t1 + t2, "1")) + assert(blacklist.isExecutorBlacklisted("2")) + assert(blacklist.isNodeBlacklisted("hostA")) + // make sure we don't leak memory + assert(!blacklist.executorIdToBlacklistStatus.contains("1")) + assert(!blacklist.nodeToBlacklistedExecs("hostA").contains("1")) + // Advance the timeout again so now hostA should be removed from the blacklist. 
+ clock.advance(t1) + blacklist.applyBlacklistTimeout() + assert(!blacklist.nodeIdToBlacklistExpiryTime.contains("hostA")) + verify(listenerBusMock).post(SparkListenerNodeUnblacklisted(t1 + t2 + t1, "hostA")) + // Even though unblacklisting a node implicitly unblacklists all of its executors, + // there will be no SparkListenerExecutorUnblacklisted sent here. + } + + test("task failures expire with time") { + // Verifies that 2 failures within the timeout period cause an executor to be blacklisted, but + // if task failures are spaced out by more than the timeout period, the first failure is timed + // out, and the executor isn't blacklisted. + var stageId = 0 + + def failOneTaskInTaskSet(exec: String): Unit = { + val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) + stageId += 1 + } + + failOneTaskInTaskSet(exec = "1") + // We have one sporadic failure on exec 2, but that's it. Later checks ensure that we never + // blacklist executor 2 despite this one failure. + failOneTaskInTaskSet(exec = "2") + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + assert(blacklist.nextExpiryTime === Long.MaxValue) + + // We advance the clock past the expiry time. + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + val t0 = clock.getTimeMillis() + blacklist.applyBlacklistTimeout() + assert(blacklist.nextExpiryTime === Long.MaxValue) + failOneTaskInTaskSet(exec = "1") + + // Because the 2nd failure on executor 1 happened past the expiry time, nothing should have been + // blacklisted. + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + + // Now we add one more failure, within the timeout, and it should be counted. + clock.setTime(t0 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + val t1 = clock.getTimeMillis() + failOneTaskInTaskSet(exec = "1") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(t1, "1", 2)) + assert(blacklist.nextExpiryTime === t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Add failures on executor 3, make sure it gets put on the blacklist. + clock.setTime(t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + val t2 = clock.getTimeMillis() + failOneTaskInTaskSet(exec = "3") + failOneTaskInTaskSet(exec = "3") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "3")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(t2, "3", 2)) + assert(blacklist.nextExpiryTime === t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Now we go past the timeout for executor 1, so it should be dropped from the blacklist. + clock.setTime(t1 + blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("3")) + verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(clock.getTimeMillis(), "1")) + assert(blacklist.nextExpiryTime === t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + // Make sure that we update correctly when we go from having blacklisted executors to + // just having tasks with timeouts. 
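+    // Executor 4 only accumulates a single failure here, so it should never be blacklisted,
+    // even while executor 3 still is.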
+ clock.setTime(t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS - 1) + failOneTaskInTaskSet(exec = "4") + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("3")) + assert(blacklist.nextExpiryTime === t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS) + + clock.setTime(t2 + blacklist.BLACKLIST_TIMEOUT_MILLIS + 1) + blacklist.applyBlacklistTimeout() + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + verify(listenerBusMock).post(SparkListenerExecutorUnblacklisted(clock.getTimeMillis(), "3")) + // we've got one task failure still, but we don't bother setting nextExpiryTime to it, to + // avoid wasting time checking for expiry of individual task failures. + assert(blacklist.nextExpiryTime === Long.MaxValue) + } + + test("task failure timeout works as expected for long-running tasksets") { + // This ensures that we don't trigger spurious blacklisting for long tasksets, when the taskset + // finishes long after the task failures. We create two tasksets, each with one failure. + // Individually they shouldn't cause any blacklisting since there is only one failure. + // Furthermore, we space the failures out so far that even when both tasksets have completed, + // we still don't trigger any blacklisting. + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) + // Taskset1 has one failure immediately + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + // Then we have a *long* delay, much longer than the timeout, before any other failures or + // taskset completion + clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) + // After the long delay, we have one failure on taskset 2, on the same executor + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We + // want to make sure that when taskset 1 finishes, even though we've now got two task failures, + // we realize that the task failure we just added was well before the timeout. 
+ clock.advance(1) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId = 2, 0, taskSetBlacklist2.execToFailures) + clock.advance(1) + blacklist.updateBlacklistForSuccessfulTaskSet(stageId = 1, 0, taskSetBlacklist1.execToFailures) + + // Make sure nothing was blacklisted + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) + } + + test("only blacklist nodes for the application when enough executors have failed on that " + + "specific host") { + // we blacklist executors on two different hosts -- make sure that doesn't lead to any + // node blacklisting + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 2)) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) + taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 2)) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) + + // Finally, blacklist another executor on the same node as the original blacklisted executor, + // and make sure this time we *do* blacklist the node. 
+ val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) + assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) + verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) + assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set("hostA")) + verify(listenerBusMock).post(SparkListenerNodeBlacklisted(0, "hostA", 2)) + } test("blacklist still respects legacy configs") { val conf = new SparkConf().setMaster("local") @@ -35,7 +416,7 @@ class BlacklistTrackerSuite extends SparkFunSuite { // if you explicitly set the legacy conf to 0, that also would disable blacklisting conf.set(config.BLACKLIST_LEGACY_TIMEOUT_CONF, 0L) assert(!BlacklistTracker.isBlacklistEnabled(conf)) - // but again, the new conf takes precendence + // but again, the new conf takes precedence conf.set(config.BLACKLIST_ENABLED, true) assert(BlacklistTracker.isBlacklistEnabled(conf)) assert(1000 === BlacklistTracker.getBlacklistTimeout(conf)) @@ -68,6 +449,8 @@ class BlacklistTrackerSuite extends SparkFunSuite { config.MAX_TASK_ATTEMPTS_PER_NODE, config.MAX_FAILURES_PER_EXEC_STAGE, config.MAX_FAILED_EXEC_PER_NODE_STAGE, + config.MAX_FAILURES_PER_EXEC, + config.MAX_FAILED_EXEC_PER_NODE, config.BLACKLIST_TIMEOUT_CONF ).foreach { config => conf.set(config.key, "0") @@ -78,4 +461,72 @@ class BlacklistTrackerSuite extends SparkFunSuite { conf.remove(config) } } + + test("blacklisting kills executors, configured by BLACKLIST_KILL_ENABLED") { + val allocationClientMock = mock[ExecutorAllocationClient] + when(allocationClientMock.killExecutors(any(), any(), any())).thenReturn(Seq("called")) + when(allocationClientMock.killExecutorsOnHost("hostA")).thenAnswer(new Answer[Boolean] { + // To avoid a race between blacklisting and killing, it is important that the nodeBlacklist + // is updated before we ask the executor allocation client to kill all the executors + // on a particular host. + override def answer(invocation: InvocationOnMock): Boolean = { + if (blacklist.nodeBlacklist.contains("hostA") == false) { + throw new IllegalStateException("hostA should be on the blacklist") + } + true + } + }) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + // Disable auto-kill. Blacklist an executor and make sure killExecutors is not called. + conf.set(config.BLACKLIST_KILL_ENABLED, false) + + val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) + // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole + // application. + (0 until 4).foreach { partition => + taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) + + verify(allocationClientMock, never).killExecutor(any()) + + val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) + // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole + // application. Since that's the second executor that is blacklisted on the same node, we also + // blacklist that node. 
+ (0 until 4).foreach { partition => + taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) + + verify(allocationClientMock, never).killExecutors(any(), any(), any()) + verify(allocationClientMock, never).killExecutorsOnHost(any()) + + // Enable auto-kill. Blacklist an executor and make sure killExecutors is called. + conf.set(config.BLACKLIST_KILL_ENABLED, true) + blacklist = new BlacklistTracker(listenerBusMock, conf, Some(allocationClientMock), clock) + + val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) + // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole + // application. + (0 until 4).foreach { partition => + taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) + + verify(allocationClientMock).killExecutors(Seq("1"), true, true) + + val taskSetBlacklist3 = createTaskSetBlacklist(stageId = 1) + // Fail 4 tasks in one task set on executor 2, so that executor gets blacklisted for the whole + // application. Since that's the second executor that is blacklisted on the same node, we also + // blacklist that node. + (0 until 4).foreach { partition => + taskSetBlacklist3.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + } + blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) + + verify(allocationClientMock).killExecutors(Seq("2"), true, true) + verify(allocationClientMock).killExecutorsOnHost("hostA") + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index bec95d13d193a..a10941b579fe2 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -110,8 +110,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val cancelledStages = new HashSet[Int]() val taskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start() = {} override def stop() = {} override def executorHeartbeatReceived( @@ -126,6 +126,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { cancelledStages += stageId } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} @@ -329,7 +331,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou /** Sends JobCancelled to the DAG scheduler. 
*/ private def cancel(jobId: Int) { - runEvent(JobCancelled(jobId)) + runEvent(JobCancelled(jobId, None)) } test("[SPARK-3353] parent stage should have lower stage id") { @@ -542,8 +544,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // make sure that the DAGScheduler doesn't crash when the TaskScheduler // doesn't implement killTask() val noKillTaskScheduler = new TaskScheduler() { - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = { @@ -552,6 +554,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou override def cancelTasks(stageId: Int, interruptThread: Boolean) { throw new UnsupportedOperationException } + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = { + throw new UnsupportedOperationException + } override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def defaultParallelism(): Int = 2 override def executorHeartbeatReceived( @@ -801,7 +807,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0, 1)) - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) @@ -813,7 +819,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // map output, for the next iteration through the loop scheduler.resubmitFailedStages() - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + if (attempt < scheduler.maxConsecutiveStageAttempts - 1) { assert(scheduler.runningStages.nonEmpty) assert(!ended) } else { @@ -847,11 +853,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // In the first two iterations, Stage 0 succeeds and stage 1 fails. In the next two iterations, // stage 2 fails. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts) { // Complete all the tasks for the current attempt of stage 0 successfully completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) - if (attempt < Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2) { + if (attempt < scheduler.maxConsecutiveStageAttempts / 2) { // Now we should have a new taskSet, for a new attempt of stage 1. 
// Fail all these tasks with FetchFailure completeNextStageWithFetchFailure(1, attempt, shuffleDepOne) @@ -859,8 +865,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, attempt, numShufflePartitions = 1) // Fail stage 2 - completeNextStageWithFetchFailure(2, attempt - Stage.MAX_CONSECUTIVE_FETCH_FAILURES / 2, - shuffleDepTwo) + completeNextStageWithFetchFailure(2, + attempt - scheduler.maxConsecutiveStageAttempts / 2, shuffleDepTwo) } // this will trigger a resubmission of stage 0, since we've lost some of its @@ -872,7 +878,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou completeShuffleMapStageSuccessfully(1, 4, numShufflePartitions = 1) // Succeed stage2 with a "42" - completeNextResultStageWithSuccess(2, Stage.MAX_CONSECUTIVE_FETCH_FAILURES/2) + completeNextResultStageWithSuccess(2, scheduler.maxConsecutiveStageAttempts / 2) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -895,7 +901,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou submit(finalRdd, Array(0)) // First, execute stages 0 and 1, failing stage 1 up to MAX-1 times. - for (attempt <- 0 until Stage.MAX_CONSECUTIVE_FETCH_FAILURES - 1) { + for (attempt <- 0 until scheduler.maxConsecutiveStageAttempts - 1) { // Make each task in stage 0 success completeShuffleMapStageSuccessfully(0, attempt, numShufflePartitions = 2) @@ -1569,24 +1575,45 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } - test("run trivial shuffle with out-of-band failure and retry") { + /** + * In this test, we run a map stage where one of the executors fails but we still receive a + * "zombie" complete message from a task that ran on that executor. We want to make sure the + * stage is resubmitted so that the task that ran on the failed executor is re-executed, and + * that the stage is only marked as finished once that task completes. + */ + test("run trivial shuffle with out-of-band executor failure and retry") { val shuffleMapRdd = new MyRDD(sc, 2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) val shuffleId = shuffleDep.shuffleId val reduceRdd = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker) submit(reduceRdd, Array(0)) - // blockManagerMaster.removeExecutor("exec-hostA") - // pretend we were told hostA went away + // Tell the DAGScheduler that hostA was lost. runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) - // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks - // rather than marking it is as failed and waiting. complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) + + // At this point, no more tasks are running for the stage (and the TaskSetManager considers the + // stage complete), but the tasks that ran on HostA need to be re-run, so the DAGScheduler + // should re-submit the stage with one task (the task that originally ran on HostA). + assert(taskSets.size === 2) + assert(taskSets(1).tasks.size === 1) + + // Make sure that the stage that was re-submitted was the ShuffleMapStage (not the reduce + // stage, which shouldn't be run until all of the tasks in the ShuffleMapStage complete on + // alive executors). 
+ assert(taskSets(1).tasks(0).isInstanceOf[ShuffleMapTask]) + // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + + // Make sure that the reduce stage was now submitted. + assert(taskSets.size === 3) + assert(taskSets(2).tasks(0).isInstanceOf[ResultTask[_, _]]) + + // Complete the reduce stage. complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -1819,7 +1846,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === HashSet(makeBlockManagerId("hostA"))) - // Reducer should run where RDD 2 has preferences, even though though it also has a shuffle dep + // Reducer should run where RDD 2 has preferences, even though it also has a shuffle dep val reduceTaskSet = taskSets(1) assertLocations(reduceTaskSet, Seq(Seq("hostB"))) complete(reduceTaskSet, Seq((Success, 42))) @@ -2031,6 +2058,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou * In this test, we run a map stage where one of the executors fails but we still receive a * "zombie" complete message from that executor. We want to make sure the stage is not reported * as done until all tasks have completed. + * + * Most of the functionality in this test is tested in "run trivial shuffle with out-of-band + * executor failure and retry". However, that test uses ShuffleMapStages that are followed by + * a ResultStage, whereas in this test, the ShuffleMapStage is tested in isolation, without a + * ResultStage after it. */ test("map stage submission with executor failure late map task completions") { val shuffleMapRdd = new MyRDD(sc, 3, Nil) @@ -2042,7 +2074,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou runEvent(makeCompletionEvent(oldTaskSet.tasks(0), Success, makeMapStatus("hostA", 2))) assert(results.size === 0) // Map stage job should not be complete yet - // Pretend host A was lost + // Pretend host A was lost. This will cause the TaskSetManager to resubmit task 0, because it + // completed on hostA. val oldEpoch = mapOutputTracker.getEpoch runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) val newEpoch = mapOutputTracker.getEpoch @@ -2054,13 +2087,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // A completion from another task should work because it's a non-failed host runEvent(makeCompletionEvent(oldTaskSet.tasks(2), Success, makeMapStatus("hostB", 2))) - assert(results.size === 0) // Map stage job should not be complete yet + + // At this point, no more tasks are running for the stage (and the TaskSetManager considers + // the stage complete), but the task that ran on hostA needs to be re-run, so the map stage + // shouldn't be marked as complete, and the DAGScheduler should re-submit the stage. + assert(results.size === 0) + assert(taskSets.size === 2) // Now complete tasks in the second task set val newTaskSet = taskSets(1) - assert(newTaskSet.tasks.size === 2) // Both tasks 0 and 1 were on on hostA + // 2 tasks should have been re-submitted, for tasks 0 and 1 (which ran on hostA). + assert(newTaskSet.tasks.size === 2) + // Complete task 0 from the original task set (i.e., not the one that's currently active).
+ // This should still be counted towards the job being complete (but there's still one + // outstanding task). runEvent(makeCompletionEvent(newTaskSet.tasks(0), Success, makeMapStatus("hostB", 2))) - assert(results.size === 0) // Map stage job should not be complete yet + assert(results.size === 0) + + // Complete the final task, from the currently active task set. There's still one + // running task, task 0 in the currently active stage attempt, but the success of task 0 means + // the DAGScheduler can mark the stage as finished. runEvent(makeCompletionEvent(newTaskSet.tasks(1), Success, makeMapStatus("hostB", 2))) assert(results.size === 1) // Map stage job should now finally be complete assertDataStructuresEmpty() @@ -2076,7 +2122,7 @@ } /** - * Checks the DAGScheduler's internal logic for traversing a RDD DAG by making sure that + * Checks the DAGScheduler's internal logic for traversing an RDD DAG by making sure that * getShuffleDependencies correctly returns the direct shuffle dependencies of a particular * RDD. The test creates the following RDD graph (where n denotes a narrow dependency and s * denotes a shuffle dependency): @@ -2161,6 +2207,76 @@ } } + test("[SPARK-19263] DAGScheduler should not submit multiple active tasksets," + + " even with late completions from earlier stage attempts") { + // Create 3 RDDs with shuffle dependencies on each other: rddA <--- rddB <--- rddC + val rddA = new MyRDD(sc, 2, Nil) + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(2)) + val shuffleIdA = shuffleDepA.shuffleId + + val rddB = new MyRDD(sc, 2, List(shuffleDepA), tracker = mapOutputTracker) + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(2)) + + val rddC = new MyRDD(sc, 2, List(shuffleDepB), tracker = mapOutputTracker) + + submit(rddC, Array(0, 1)) + + // Complete both tasks in rddA. + assert(taskSets(0).stageId === 0 && taskSets(0).stageAttemptId === 0) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostA", 2)))) + + // Fetch failed for task(stageId=1, stageAttemptId=0, partitionId=0) running on hostA + // and task(stageId=1, stageAttemptId=0, partitionId=1) is still running. + assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0) + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0, 0, + "Fetch failure of task: stageId=1, stageAttempt=0, partitionId=0"), + result = null)) + + // Both original tasks in rddA should be marked as failed, because they ran on the + // failed hostA, so both should be resubmitted. Complete them on hostB successfully. + scheduler.resubmitFailedStages() + assert(taskSets(2).stageId === 0 && taskSets(2).stageAttemptId === 1 + && taskSets(2).tasks.size === 2) + complete(taskSets(2), Seq( + (Success, makeMapStatus("hostB", 2)), + (Success, makeMapStatus("hostB", 2)))) + + // Complete task(stageId=1, stageAttemptId=0, partitionId=1) running on failed hostA + // successfully. The success should be ignored because the task started before the + // executor failed, so the output may have been lost. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), Success, makeMapStatus("hostA", 2))) + + // Both tasks in rddB should be resubmitted, because none of them has succeeded truly.
+ // Complete the task(stageId=1, stageAttemptId=1, partitionId=0) successfully. + // Task(stageId=1, stageAttemptId=1, partitionId=1) of this new active stage attempt + // is still running. + assert(taskSets(3).stageId === 1 && taskSets(3).stageAttemptId === 1 + && taskSets(3).tasks.size === 2) + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), Success, makeMapStatus("hostB", 2))) + + // There should be no new attempt of stage submitted, + // because task(stageId=1, stageAttempt=1, partitionId=1) is still running in + // the current attempt (and hasn't completed successfully in any earlier attempts). + assert(taskSets.size === 4) + + // Complete task(stageId=1, stageAttempt=1, partitionId=1) successfully. + runEvent(makeCompletionEvent( + taskSets(3).tasks(1), Success, makeMapStatus("hostB", 2))) + + // Now the ResultStage should be submitted, because all of the tasks of rddB have + // completed successfully on alive executors. + assert(taskSets.size === 5 && taskSets(4).tasks(0).isInstanceOf[ResultTask[_, _]]) + complete(taskSets(4), Seq( + (Success, 1), + (Success, 1))) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 7f4859206e257..4c3d0b102152c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -95,6 +95,18 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } } + test("Event logging with password redaction") { + val key = "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" + val secretPassword = "secret_password" + val conf = getLoggingConf(testDirPath, None) + .set(key, secretPassword) + val eventLogger = new EventLoggingListener("test", None, testDirPath.toUri(), conf) + val envDetails = SparkEnv.environmentDetails(conf, "FIFO", Seq.empty, Seq.empty) + val event = SparkListenerEnvironmentUpdate(envDetails) + val redactedProps = eventLogger.redactEvent(event).environmentDetails("Spark Properties").toMap + assert(redactedProps(key) == "*********(redacted)") + } + test("Log overwriting") { val logUri = EventLoggingListener.getLogPath(testDir.toURI, "test", None) val logPath = new URI(logUri).getPath @@ -107,19 +119,20 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit } test("Event log name") { + val baseDirUri = Utils.resolveURI("/base-dir") // without compression - assert(s"file:/base-dir/app1" === EventLoggingListener.getLogPath( - Utils.resolveURI("/base-dir"), "app1", None)) + assert(s"${baseDirUri.toString}/app1" === EventLoggingListener.getLogPath( + baseDirUri, "app1", None)) // with compression - assert(s"file:/base-dir/app1.lzf" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), "app1", None, Some("lzf"))) + assert(s"${baseDirUri.toString}/app1.lzf" === + EventLoggingListener.getLogPath(baseDirUri, "app1", None, Some("lzf"))) // illegal characters in app ID - assert(s"file:/base-dir/a-fine-mind_dollar_bills__1" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1" === + EventLoggingListener.getLogPath(baseDirUri, "a fine:mind$dollar{bills}.1", None)) // illegal characters in app ID with compression - 
assert(s"file:/base-dir/a-fine-mind_dollar_bills__1.lz4" === - EventLoggingListener.getLogPath(Utils.resolveURI("/base-dir"), + assert(s"${baseDirUri.toString}/a-fine-mind_dollar_bills__1.lz4" === + EventLoggingListener.getLogPath(baseDirUri, "a fine:mind$dollar{bills}.1", None, Some("lz4"))) } @@ -202,8 +215,6 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit // Make sure expected events exist in the log file. val logData = EventLoggingListener.openEventLog(new Path(eventLogger.logPath), fileSystem) - val logStart = SparkListenerLogStart(SPARK_VERSION) - val lines = readLines(logData) val eventSet = mutable.Set( SparkListenerApplicationStart, SparkListenerBlockManagerAdded, @@ -216,19 +227,25 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerTaskStart, SparkListenerTaskEnd, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) - lines.foreach { line => - eventSet.foreach { event => - if (line.contains(event)) { - val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) - val eventType = Utils.getFormattedClassName(parsedEvent) - if (eventType == event) { - eventSet.remove(event) + Utils.tryWithSafeFinally { + val logStart = SparkListenerLogStart(SPARK_VERSION) + val lines = readLines(logData) + lines.foreach { line => + eventSet.foreach { event => + if (line.contains(event)) { + val parsedEvent = JsonProtocol.sparkEventFromJson(parse(line)) + val eventType = Utils.getFormattedClassName(parsedEvent) + if (eventType == event) { + eventSet.remove(event) + } } } } + assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) + assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) + } { + logData.close() } - assert(JsonProtocol.sparkEventFromJson(parse(lines(0))) === logStart) - assert(eventSet.isEmpty, "The following events are missing: " + eventSet.toSeq) } private def readLines(in: InputStream): Seq[String] = { @@ -273,7 +290,7 @@ object EventLoggingListenerSuite { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") conf.set("spark.eventLog.testing", "true") - conf.set("spark.eventLog.dir", logDir.toString) + conf.set("spark.eventLog.dir", logDir.toUri.toString) compressionCodec.foreach { codec => conf.set("spark.eventLog.compress", "true") conf.set("spark.io.compression.codec", codec) diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index e87cebf0cf358..ba56af8215cd7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -73,12 +73,14 @@ private class DummySchedulerBackend extends SchedulerBackend { private class DummyTaskScheduler extends TaskScheduler { var initialized = false - override def rootPool: Pool = null - override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def schedulingMode: SchedulingMode = SchedulingMode.FIFO + override def rootPool: Pool = new Pool("", schedulingMode, 0, 0) override def start(): Unit = {} override def stop(): Unit = {} override def submitTasks(taskSet: TaskSet): Unit = {} override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def killTaskAttempt( + taskId: Long, interruptThread: Boolean, reason: String): Boolean = false override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} override def 
defaultParallelism(): Int = 2 override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index a757041299411..fe6de2bd98850 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,12 +17,20 @@ package org.apache.spark.scheduler +import java.util.Properties + +import org.apache.spark.SparkEnv import org.apache.spark.TaskContext +import org.apache.spark.executor.TaskMetrics class FakeTask( stageId: Int, partitionId: Int, - prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) { + prefLocs: Seq[TaskLocation] = Nil, + serializedTaskMetrics: Array[Byte] = + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) { + override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 83288db92bb43..e51e6a0d3ff6b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -31,6 +31,7 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.internal.io.SparkHadoopWriter import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.{ThreadUtils, Utils} @@ -158,10 +159,9 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { 0 until rdd.partitions.size, resultHandler, () => Unit) // It's an error if the job completes successfully even though no committer was authorized, // so throw an exception if the job was allowed to complete. 
- val e = intercept[SparkException] { + intercept[TimeoutException] { ThreadUtils.awaitResult(futureAction, 5 seconds) } - assert(e.getCause.isInstanceOf[TimeoutException]) assert(tempDir.list().size === 0) } @@ -176,13 +176,13 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter)) // The non-authorized committer fails outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test")) // New tasks should still not be able to commit because the authorized committer has not failed assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1)) // The authorized committer now fails, clearing the lock outputCommitCoordinator.taskCompleted( - stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled) + stage, partition, attemptNumber = authorizedCommitter, reason = TaskKilled("test")) // A new task should now be allowed to become the authorized committer assert( outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2)) @@ -190,6 +190,23 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { assert( !outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3)) } + + test("Duplicate calls to canCommit from the authorized committer gets idempotent responses.") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, + 0 until rdd.partitions.size) + } + + test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { + val stage: Int = 1 + val partition: Int = 1 + val failedAttempt: Int = 0 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + } } /** @@ -222,6 +239,16 @@ private case class OutputCommitFunctions(tempDirPath: String) { if (ctx.attemptNumber == 0) failingOutputCommitter else successfulOutputCommitter) } + // Receiver should be idempotent for AskPermissionToCommitOutput + def callCanCommitMultipleTimes(iter: Iterator[Int]): Unit = { + val ctx = TaskContext.get() + val canCommit1 = SparkEnv.get.outputCommitCoordinator + .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) + val canCommit2 = SparkEnv.get.outputCommitCoordinator + .canCommit(ctx.stageId(), ctx.partitionId(), ctx.attemptNumber()) + assert(canCommit1 && canCommit2) + } + private def runCommitWithProvidedCommitter( ctx: TaskContext, iter: Iterator[Int], diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 00e1c447ccbef..4901062a78553 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.util.Properties import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.SchedulingMode._ /** * Tests that pools and the associated scheduling 
algorithms for FIFO and fair scheduling work @@ -27,6 +28,11 @@ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSui */ class PoolSuite extends SparkFunSuite with LocalSparkContext { + val LOCAL = "local" + val APP_NAME = "PoolSuite" + val SCHEDULER_ALLOCATION_FILE_PROPERTY = "spark.scheduler.allocation.file" + val TEST_POOL = "testPool" + def createTaskSetManager(stageId: Int, numTasks: Int, taskScheduler: TaskSchedulerImpl) : TaskSetManager = { val tasks = Array.tabulate[Task[_]](numTasks) { i => @@ -35,7 +41,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null), 0) } - def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int) { + def scheduleTaskAndVerifyId(taskId: Int, rootPool: Pool, expectedStageId: Int): Unit = { val taskSetQueue = rootPool.getSortedTaskSetQueue val nextTaskSetToSchedule = taskSetQueue.find(t => (t.runningTasks + t.tasksSuccessful) < t.numTasks) @@ -45,12 +51,11 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { } test("FIFO Scheduler Test") { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + sc = new SparkContext(LOCAL, APP_NAME) val taskScheduler = new TaskSchedulerImpl(sc) - val rootPool = new Pool("", SchedulingMode.FIFO, 0, 0) + val rootPool = new Pool("", FIFO, 0, 0) val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) - schedulableBuilder.buildPools() val taskSetManager0 = createTaskSetManager(0, 2, taskScheduler) val taskSetManager1 = createTaskSetManager(1, 2, taskScheduler) @@ -74,30 +79,24 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { */ test("Fair Scheduler Test") { val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile() - val conf = new SparkConf().set("spark.scheduler.allocation.file", xmlPath) - sc = new SparkContext("local", "TaskSchedulerImplSuite", conf) + val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, xmlPath) + sc = new SparkContext(LOCAL, APP_NAME, conf) val taskScheduler = new TaskSchedulerImpl(sc) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) + val rootPool = new Pool("", FAIR, 0, 0) val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) schedulableBuilder.buildPools() // Ensure that the XML file was read in correctly. 
- assert(rootPool.getSchedulableByName("default") != null) - assert(rootPool.getSchedulableByName("1") != null) - assert(rootPool.getSchedulableByName("2") != null) - assert(rootPool.getSchedulableByName("3") != null) - assert(rootPool.getSchedulableByName("1").minShare === 2) - assert(rootPool.getSchedulableByName("1").weight === 1) - assert(rootPool.getSchedulableByName("2").minShare === 3) - assert(rootPool.getSchedulableByName("2").weight === 1) - assert(rootPool.getSchedulableByName("3").minShare === 0) - assert(rootPool.getSchedulableByName("3").weight === 1) + verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO) + verifyPool(rootPool, "1", 2, 1, FIFO) + verifyPool(rootPool, "2", 3, 1, FIFO) + verifyPool(rootPool, "3", 0, 1, FIFO) val properties1 = new Properties() - properties1.setProperty("spark.scheduler.pool", "1") + properties1.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, "1") val properties2 = new Properties() - properties2.setProperty("spark.scheduler.pool", "2") + properties2.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, "2") val taskSetManager10 = createTaskSetManager(0, 1, taskScheduler) val taskSetManager11 = createTaskSetManager(1, 1, taskScheduler) @@ -134,22 +133,22 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { } test("Nested Pool Test") { - sc = new SparkContext("local", "TaskSchedulerImplSuite") + sc = new SparkContext(LOCAL, APP_NAME) val taskScheduler = new TaskSchedulerImpl(sc) - val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0) - val pool0 = new Pool("0", SchedulingMode.FAIR, 3, 1) - val pool1 = new Pool("1", SchedulingMode.FAIR, 4, 1) + val rootPool = new Pool("", FAIR, 0, 0) + val pool0 = new Pool("0", FAIR, 3, 1) + val pool1 = new Pool("1", FAIR, 4, 1) rootPool.addSchedulable(pool0) rootPool.addSchedulable(pool1) - val pool00 = new Pool("00", SchedulingMode.FAIR, 2, 2) - val pool01 = new Pool("01", SchedulingMode.FAIR, 1, 1) + val pool00 = new Pool("00", FAIR, 2, 2) + val pool01 = new Pool("01", FAIR, 1, 1) pool0.addSchedulable(pool00) pool0.addSchedulable(pool01) - val pool10 = new Pool("10", SchedulingMode.FAIR, 2, 2) - val pool11 = new Pool("11", SchedulingMode.FAIR, 2, 1) + val pool10 = new Pool("10", FAIR, 2, 2) + val pool11 = new Pool("11", FAIR, 2, 1) pool1.addSchedulable(pool10) pool1.addSchedulable(pool11) @@ -178,4 +177,127 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { scheduleTaskAndVerifyId(2, rootPool, 6) scheduleTaskAndVerifyId(3, rootPool, 2) } + + test("SPARK-17663: FairSchedulableBuilder sets default values for blank or invalid datas") { + val xmlPath = getClass.getClassLoader.getResource("fairscheduler-with-invalid-data.xml") + .getFile() + val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, xmlPath) + + val rootPool = new Pool("", FAIR, 0, 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, conf) + schedulableBuilder.buildPools() + + verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO) + verifyPool(rootPool, "pool_with_invalid_min_share", 0, 2, FAIR) + verifyPool(rootPool, "pool_with_invalid_weight", 1, 1, FAIR) + verifyPool(rootPool, "pool_with_invalid_scheduling_mode", 3, 2, FIFO) + verifyPool(rootPool, "pool_with_non_uppercase_scheduling_mode", 2, 1, FAIR) + verifyPool(rootPool, "pool_with_NONE_scheduling_mode", 1, 2, FIFO) + verifyPool(rootPool, "pool_with_whitespace_min_share", 0, 2, FAIR) + verifyPool(rootPool, "pool_with_whitespace_weight", 1, 1, FAIR) + verifyPool(rootPool, 
"pool_with_whitespace_scheduling_mode", 3, 2, FIFO) + verifyPool(rootPool, "pool_with_empty_min_share", 0, 3, FAIR) + verifyPool(rootPool, "pool_with_empty_weight", 2, 1, FAIR) + verifyPool(rootPool, "pool_with_empty_scheduling_mode", 2, 2, FIFO) + verifyPool(rootPool, "pool_with_surrounded_whitespace", 3, 2, FAIR) + } + + /** + * spark.scheduler.pool property should be ignored for the FIFO scheduler, + * because pools are only needed for fair scheduling. + */ + test("FIFO scheduler uses root pool and not spark.scheduler.pool property") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FIFO, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FIFOSchedulableBuilder(rootPool) + + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty("spark.scheduler.pool", TEST_POOL) + + // When FIFO Scheduler is used and task sets are submitted, they should be added to + // the root pool, and no additional pools should be created + // (even though there's a configured default pool). + schedulableBuilder.addTaskSetManager(taskSetManager0, properties) + schedulableBuilder.addTaskSetManager(taskSetManager1, properties) + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + assert(rootPool.schedulableQueue.size === 2) + assert(rootPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + assert(rootPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler uses default pool when spark.scheduler.pool property is not set") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + // Submit a new task set manager with pool properties set to null. This should result + // in the task set manager getting added to the default pool. + val taskSetManager0 = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager0, null) + + val defaultPool = rootPool.getSchedulableByName(schedulableBuilder.DEFAULT_POOL_NAME) + assert(defaultPool !== null) + assert(defaultPool.schedulableQueue.size === 1) + assert(defaultPool.getSchedulableByName(taskSetManager0.name) === taskSetManager0) + + // When a task set manager is submitted with spark.scheduler.pool unset, it should be added to + // the default pool (as above). 
+ val taskSetManager1 = createTaskSetManager(stageId = 1, numTasks = 1, taskScheduler) + schedulableBuilder.addTaskSetManager(taskSetManager1, new Properties()) + + assert(defaultPool.schedulableQueue.size === 2) + assert(defaultPool.getSchedulableByName(taskSetManager1.name) === taskSetManager1) + } + + test("FAIR Scheduler creates a new pool when spark.scheduler.pool property points to " + + "a non-existent pool") { + sc = new SparkContext("local", "PoolSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + + val rootPool = new Pool("", SchedulingMode.FAIR, initMinShare = 0, initWeight = 0) + val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf) + schedulableBuilder.buildPools() + + assert(rootPool.getSchedulableByName(TEST_POOL) === null) + + val taskSetManager = createTaskSetManager(stageId = 0, numTasks = 1, taskScheduler) + + val properties = new Properties() + properties.setProperty(schedulableBuilder.FAIR_SCHEDULER_PROPERTIES, TEST_POOL) + + // The fair scheduler should create a new pool with default values when spark.scheduler.pool + // points to a pool that doesn't exist yet (this can happen when the file that pools are read + // from isn't set, or when that file doesn't contain the pool name specified + // by spark.scheduler.pool). + schedulableBuilder.addTaskSetManager(taskSetManager, properties) + + verifyPool(rootPool, TEST_POOL, schedulableBuilder.DEFAULT_MINIMUM_SHARE, + schedulableBuilder.DEFAULT_WEIGHT, schedulableBuilder.DEFAULT_SCHEDULING_MODE) + val testPool = rootPool.getSchedulableByName(TEST_POOL) + assert(testPool.getSchedulableByName(taskSetManager.name) === taskSetManager) + } + + test("Pool should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + new Pool("TestPool", SchedulingMode.NONE, 0, 1) + } + } + + private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int, + expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = { + val selectedPool = rootPool.getSchedulableByName(poolName) + assert(selectedPool !== null) + assert(selectedPool.minShare === expectedInitMinShare) + assert(selectedPool.weight === expectedInitWeight) + assert(selectedPool.schedulingMode === expectedSchedulingMode) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index c28aa06623a60..8300607ea888b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -28,6 +28,8 @@ import scala.reflect.ClassTag import org.scalactic.TripleEquals import org.scalatest.Assertions.AssertionsHelper +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.TaskState._ @@ -93,12 +95,12 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } /** - * A map from partition -> results for all tasks of a job when you call this test framework's + * A map from partition to results for all tasks of a job when you call this test framework's * [[submit]] method. Two important considerations: * * 1. If there is a job failure, results may or may not be empty. If any tasks succeed before * the job has failed, they will get included in `results`. Instead, check for job failure by - * checking [[failure]]. 
(Also see [[assertDataStructuresEmpty()]]) + * checking [[failure]]. (Also see `assertDataStructuresEmpty()`) * * 2. This only gets cleared between tests. So you'll need to do special handling if you submit * more than one job in one test. @@ -157,8 +159,16 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa } // When a job fails, we terminate before waiting for all the task end events to come in, // so there might still be a running task set. So we only check these conditions - // when the job succeeds - assert(taskScheduler.runningTaskSets.isEmpty) + // when the job succeeds. + // When the final task of a taskset completes, we post + // the event to the DAGScheduler event loop before we finish processing in the taskscheduler + // thread. It's possible the DAGScheduler thread processes the event, finishes the job, + // and notifies the job waiter before our original thread in the task scheduler finishes + // handling the event and marks the taskset as complete. So its ok if we need to wait a + // *little* bit longer for the original taskscheduler thread to finish up to deal w/ the race. + eventually(timeout(1 second), interval(10 millis)) { + assert(taskScheduler.runningTaskSets.isEmpty) + } assert(!backend.hasTasks) } else { assert(failure != null) @@ -381,17 +391,17 @@ private[spark] abstract class MockBackend( * scheduling. */ override def reviveOffers(): Unit = { - val newTaskDescriptions = taskScheduler.resourceOffers(generateOffers()).flatten - // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual - // tests from introducing a race if they need it - val newTasks = taskScheduler.synchronized { - newTaskDescriptions.map { taskDescription => + // Need a lock on the entire scheduler to protect freeCores -- otherwise, multiple threads + // may make offers at the same time, though they are using the same set of freeCores. + taskScheduler.synchronized { + val newTaskDescriptions = taskScheduler.resourceOffers(generateOffers()).flatten + // get the task now, since that requires a lock on TaskSchedulerImpl, to prevent individual + // tests from introducing a race if they need it. + val newTasks = newTaskDescriptions.map { taskDescription => val taskSet = taskScheduler.taskIdToTaskSetManager(taskDescription.taskId).taskSet val task = taskSet.tasks(taskDescription.index) (taskDescription, task) } - } - synchronized { newTasks.foreach { case (taskDescription, _) => executorIdToExecutor(taskDescription.executorId).freeCores -= taskScheduler.CPUS_PER_TASK } @@ -400,7 +410,8 @@ private[spark] abstract class MockBackend( } } - override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + override def killTask( + taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit = { // We have to implement this b/c of SPARK-15385. // Its OK for this to be a no-op, because even if a backend does implement killTask, // it really can only be "best-effort" in any case, and the scheduler should be robust to that. 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index e8a88d4909a83..80c7e0bfee6ef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -184,7 +184,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be {1} val stageInfo2 = listener.stageInfos.keys.find(_.stageId == 1).get - stageInfo2.rddInfos.size should be {3} // ParallelCollectionRDD, FilteredRDD, MappedRDD + stageInfo2.rddInfos.size should be {3} stageInfo2.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo2.rddInfos.exists(_.name == "Deux") should be {true} listener.stageInfos.clear() @@ -229,7 +229,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } val numSlices = 16 - val d = sc.parallelize(0 to 1e3.toInt, numSlices).map(w) + val d = sc.parallelize(0 to 10000, numSlices).map(w) d.count() sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) listener.stageInfos.size should be (1) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 9eda79ace18d0..992d3396d203f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -62,7 +62,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, + closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null) } @@ -83,7 +84,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, + closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null) } @@ -98,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark context.addTaskCompletionListener(_ => throw new Exception("blah")) intercept[TaskCompletionListenerException] { - context.markTaskCompleted() + context.markTaskCompleted(None) } verify(listener, times(1)).onTaskCompletion(any()) @@ -196,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. 
- val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), @@ -226,6 +228,62 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(res === Array("testPropValue,testPropValue")) } + test("immediately call a completion listener if the context is completed") { + var invocations = 0 + val context = TaskContext.empty() + context.markTaskCompleted(None) + context.addTaskCompletionListener(_ => invocations += 1) + assert(invocations == 1) + context.markTaskCompleted(None) + assert(invocations == 1) + } + + test("immediately call a failure listener if the context has failed") { + var invocations = 0 + var lastError: Throwable = null + val error = new RuntimeException + val context = TaskContext.empty() + context.markTaskFailed(error) + context.addTaskFailureListener { (_, e) => + lastError = e + invocations += 1 + } + assert(lastError == error) + assert(invocations == 1) + context.markTaskFailed(error) + assert(lastError == error) + assert(invocations == 1) + } + + test("TaskCompletionListenerException.getMessage should include previousError") { + val listenerErrorMessage = "exception in listener" + val taskErrorMessage = "exception in task" + val e = new TaskCompletionListenerException( + Seq(listenerErrorMessage), + Some(new RuntimeException(taskErrorMessage))) + assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage)) + } + + test("all TaskCompletionListeners should be called even if some fail or a task") { + val context = TaskContext.empty() + val listener = mock(classOf[TaskCompletionListener]) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener1")) + context.addTaskCompletionListener(listener) + context.addTaskCompletionListener(_ => throw new Exception("exception in listener3")) + + val e = intercept[TaskCompletionListenerException] { + context.markTaskCompleted(Some(new Exception("exception in task"))) + } + + // Make sure listener 2 was called. + verify(listener, times(1)).onTaskCompletion(any()) + + // also need to check failure in TaskCompletionListener does not mask earlier exception + assert(e.getMessage.contains("exception in listener1")) + assert(e.getMessage.contains("exception in listener3")) + assert(e.getMessage.contains("exception in task")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala new file mode 100644 index 0000000000000..97487ce1d2ca8 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException} +import java.nio.ByteBuffer +import java.util.Properties + +import scala.collection.mutable.HashMap + +import org.apache.spark.SparkFunSuite + +class TaskDescriptionSuite extends SparkFunSuite { + test("encoding and then decoding a TaskDescription results in the same TaskDescription") { + val originalFiles = new HashMap[String, Long]() + originalFiles.put("fileUrl1", 1824) + originalFiles.put("fileUrl2", 2) + + val originalJars = new HashMap[String, Long]() + originalJars.put("jar1", 3) + + val originalProperties = new Properties() + originalProperties.put("property1", "18") + originalProperties.put("property2", "test value") + // SPARK-19796 -- large property values (like a large job description for a long sql query) + // can cause problems for DataOutputStream, make sure we handle correctly + val sb = new StringBuilder() + (0 to 10000).foreach(_ => sb.append("1234567890")) + val largeString = sb.toString() + originalProperties.put("property3", largeString) + // make sure we've got a good test case + intercept[UTFDataFormatException] { + val out = new DataOutputStream(new ByteArrayOutputStream()) + try { + out.writeUTF(largeString) + } finally { + out.close() + } + } + + // Create a dummy byte buffer for the task. + val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) + + val originalTaskDescription = new TaskDescription( + taskId = 1520589, + attemptNumber = 2, + executorId = "testExecutor", + name = "task for test", + index = 19, + originalFiles, + originalJars, + originalProperties, + taskBuffer + ) + + val serializedTaskDescription = TaskDescription.encode(originalTaskDescription) + val decodedTaskDescription = TaskDescription.decode(serializedTaskDescription) + + // Make sure that all of the fields in the decoded task description match the original. 
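The SPARK-19796 comment above hinges on a fixed limit in `java.io.DataOutputStream`: `writeUTF` prefixes the string with an unsigned 16-bit length, so any value whose modified-UTF-8 encoding exceeds 65535 bytes throws `UTFDataFormatException`. Below is a self-contained sketch of that failure and of the usual workaround (an explicit int length prefix plus raw UTF-8 bytes); the helper names are made up, and this is not necessarily how `TaskDescription.encode` handles it.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream,
  UTFDataFormatException}
import java.nio.charset.StandardCharsets

object WriteUtfLimitSketch {
  // Hypothetical helpers: an int length prefix plus raw UTF-8 bytes sidesteps writeUTF's limit.
  def writeLongString(out: DataOutputStream, s: String): Unit = {
    val bytes = s.getBytes(StandardCharsets.UTF_8)
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  def readLongString(in: DataInputStream): String = {
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  def main(args: Array[String]): Unit = {
    val largeString = "1234567890" * 10000  // ~100 KB, well past writeUTF's 65535-byte limit

    // writeUTF fails because its two-byte length prefix cannot represent the encoded size.
    val overflowed =
      try {
        new DataOutputStream(new ByteArrayOutputStream()).writeUTF(largeString)
        false
      } catch {
        case _: UTFDataFormatException => true
      }
    assert(overflowed)

    // The explicit length-prefix encoding round-trips the same string without trouble.
    val buffer = new ByteArrayOutputStream()
    writeLongString(new DataOutputStream(buffer), largeString)
    val decoded =
      readLongString(new DataInputStream(new ByteArrayInputStream(buffer.toByteArray)))
    assert(decoded == largeString)
  }
}
```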
+ assert(decodedTaskDescription.taskId === originalTaskDescription.taskId) + assert(decodedTaskDescription.attemptNumber === originalTaskDescription.attemptNumber) + assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) + assert(decodedTaskDescription.name === originalTaskDescription.name) + assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.addedFiles.equals(originalFiles)) + assert(decodedTaskDescription.addedJars.equals(originalJars)) + assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) + assert(decodedTaskDescription.serializedTask.equals(taskBuffer)) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index 9e472f900b655..3e55d399e9df9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.File +import java.io.{File, ObjectInputStream} import java.net.URL import java.nio.ByteBuffer @@ -171,7 +171,7 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local val tempDir = Utils.createTempDir() val srcDir = new File(tempDir, "repro/") srcDir.mkdirs() - val excSource = new JavaSourceFromString(new File(srcDir, "MyException").getAbsolutePath, + val excSource = new JavaSourceFromString(new File(srcDir, "MyException").toURI.getPath, """package repro; | |public class MyException extends Exception { @@ -183,9 +183,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local // ensure we reset the classloader after the test completes val originalClassLoader = Thread.currentThread.getContextClassLoader - try { + val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) + Utils.tryWithSafeFinally { // load the exception from the jar - val loader = new MutableURLClassLoader(new Array[URL](0), originalClassLoader) loader.addURL(jarFile.toURI.toURL) Thread.currentThread().setContextClassLoader(loader) val excClass: Class[_] = Utils.classForName("repro.MyException") @@ -209,8 +209,9 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local assert(expectedFailure.findFirstMatchIn(exceptionMessage).isDefined) assert(unknownFailure.findFirstMatchIn(exceptionMessage).isEmpty) - } finally { + } { Thread.currentThread.setContextClassLoader(originalClassLoader) + loader.close() } } @@ -247,5 +248,24 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local assert(resSizeAfter.exists(_.toString.toLong > 0L)) } + test("failed task is handled when error occurs deserializing the reason") { + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(Seq(1), 1).map { _ => + throw new UndeserializableException + } + val message = intercept[SparkException] { + rdd.collect() + }.getMessage + // Job failed, even though the failure reason is unknown. 
+ val unknownFailure = """(?s).*Lost task.*: UnknownReason.*""".r + assert(unknownFailure.findFirstMatchIn(message).isDefined) + } + +} + +private class UndeserializableException extends Exception { + private def readObject(in: ObjectInputStream): Unit = { + throw new NoClassDefFoundError() + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index f5f1947661d9a..8b9d45f734cda 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -17,11 +17,19 @@ package org.apache.spark.scheduler +import java.nio.ByteBuffer + +import scala.collection.mutable.HashMap + +import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq} +import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when} import org.scalatest.BeforeAndAfterEach +import org.scalatest.mock.MockitoSugar import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.util.ManualClock class FakeSchedulerBackend extends SchedulerBackend { def start() {} @@ -31,20 +39,26 @@ class FakeSchedulerBackend extends SchedulerBackend { } class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfterEach - with Logging { + with Logging with MockitoSugar { var failedTaskSetException: Option[Throwable] = None var failedTaskSetReason: String = null var failedTaskSet = false + var blacklist: BlacklistTracker = null var taskScheduler: TaskSchedulerImpl = null var dagScheduler: DAGScheduler = null + val stageToMockTaskSetBlacklist = new HashMap[Int, TaskSetBlacklist]() + val stageToMockTaskSetManager = new HashMap[Int, TaskSetManager]() + override def beforeEach(): Unit = { super.beforeEach() failedTaskSet = false failedTaskSetException = None failedTaskSetReason = null + stageToMockTaskSetBlacklist.clear() + stageToMockTaskSetManager.clear() } override def afterEach(): Unit = { @@ -61,11 +75,34 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B def setupScheduler(confs: (String, String)*): TaskSchedulerImpl = { val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") - confs.foreach { case (k, v) => - conf.set(k, v) - } + confs.foreach { case (k, v) => conf.set(k, v) } sc = new SparkContext(conf) taskScheduler = new TaskSchedulerImpl(sc) + setupHelper() + } + + def setupSchedulerWithMockTaskSetBlacklist(): TaskSchedulerImpl = { + blacklist = mock[BlacklistTracker] + val conf = new SparkConf().setMaster("local").setAppName("TaskSchedulerImplSuite") + conf.set(config.BLACKLIST_ENABLED, true) + sc = new SparkContext(conf) + taskScheduler = + new TaskSchedulerImpl(sc, sc.conf.getInt("spark.task.maxFailures", 4), Some(blacklist)) { + override def createTaskSetManager(taskSet: TaskSet, maxFailures: Int): TaskSetManager = { + val tsm = super.createTaskSetManager(taskSet, maxFailures) + // we need to create a spied tsm just so we can set the TaskSetBlacklist + val tsmSpy = spy(tsm) + val taskSetBlacklist = mock[TaskSetBlacklist] + when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(taskSetBlacklist)) + stageToMockTaskSetManager(taskSet.stageId) = tsmSpy + stageToMockTaskSetBlacklist(taskSet.stageId) = taskSetBlacklist + tsmSpy + } + } + setupHelper() + } + + def setupHelper(): TaskSchedulerImpl = { taskScheduler.initialize(new FakeSchedulerBackend) 
// Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. dagScheduler = new DAGScheduler(sc, taskScheduler) { @@ -282,6 +319,300 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!failedTaskSet) } + test("scheduled tasks obey task and stage blacklists") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + (0 to 2).foreach {stageId => + val taskSet = FakeTask.createTaskSet(numTasks = 2, stageId = stageId, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + } + + // Setup our mock blacklist: + // * stage 0 is blacklisted on node "host1" + // * stage 1 is blacklisted on executor "executor3" + // * stage 0, partition 0 is blacklisted on executor 0 + // (mocked methods default to returning false, ie. no blacklisting) + when(stageToMockTaskSetBlacklist(0).isNodeBlacklistedForTaskSet("host1")).thenReturn(true) + when(stageToMockTaskSetBlacklist(1).isExecutorBlacklistedForTaskSet("executor3")) + .thenReturn(true) + when(stageToMockTaskSetBlacklist(0).isExecutorBlacklistedForTask("executor0", 0)) + .thenReturn(true) + + val offers = IndexedSeq( + new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1), + new WorkerOffer("executor2", "host1", 1), + new WorkerOffer("executor3", "host2", 10) + ) + val firstTaskAttempts = taskScheduler.resourceOffers(offers).flatten + // We should schedule all tasks. + assert(firstTaskAttempts.size === 6) + // Whenever we schedule a task, we must consult the node and executor blacklist. (The test + // doesn't check exactly what checks are made because the offers get shuffled.) + (0 to 2).foreach { stageId => + verify(stageToMockTaskSetBlacklist(stageId), atLeast(1)) + .isNodeBlacklistedForTaskSet(anyString()) + verify(stageToMockTaskSetBlacklist(stageId), atLeast(1)) + .isExecutorBlacklistedForTaskSet(anyString()) + } + + def tasksForStage(stageId: Int): Seq[TaskDescription] = { + firstTaskAttempts.filter{_.name.contains(s"stage $stageId")} + } + tasksForStage(0).foreach { task => + // executors 1 & 2 blacklisted for node + // executor 0 blacklisted just for partition 0 + if (task.index == 0) { + assert(task.executorId === "executor3") + } else { + assert(Set("executor0", "executor3").contains(task.executorId)) + } + } + tasksForStage(1).foreach { task => + // executor 3 blacklisted + assert("executor3" != task.executorId) + } + // no restrictions on stage 2 + + // Finally, just make sure that we can still complete tasks as usual with blacklisting + // in effect. Finish each of the tasksets -- taskset 0 & 1 complete successfully, taskset 2 + // fails. + (0 to 2).foreach { stageId => + val tasks = tasksForStage(stageId) + val tsm = taskScheduler.taskSetManagerForAttempt(stageId, 0).get + val valueSer = SparkEnv.get.serializer.newInstance() + if (stageId == 2) { + // Just need to make one task fail 4 times. + var task = tasks(0) + val taskIndex = task.index + (0 until 4).foreach { attempt => + assert(task.attemptNumber === attempt) + tsm.handleFailedTask(task.taskId, TaskState.FAILED, TaskResultLost) + val nextAttempts = + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("executor4", "host4", 1))).flatten + if (attempt < 3) { + assert(nextAttempts.size === 1) + task = nextAttempts(0) + assert(task.index === taskIndex) + } else { + assert(nextAttempts.size === 0) + } + } + // End the other task of the taskset, doesn't matter whether it succeeds or fails. 
+ val otherTask = tasks(1) + val result = new DirectTaskResult[Int](valueSer.serialize(otherTask.taskId), Seq()) + tsm.handleSuccessfulTask(otherTask.taskId, result) + } else { + tasks.foreach { task => + val result = new DirectTaskResult[Int](valueSer.serialize(task.taskId), Seq()) + tsm.handleSuccessfulTask(task.taskId, result) + } + } + assert(tsm.isZombie) + } + + // the tasksSets complete, so the tracker should be notified of the successful ones + verify(blacklist, times(1)).updateBlacklistForSuccessfulTaskSet( + stageId = 0, + stageAttemptId = 0, + failuresByExec = stageToMockTaskSetBlacklist(0).execToFailures) + verify(blacklist, times(1)).updateBlacklistForSuccessfulTaskSet( + stageId = 1, + stageAttemptId = 0, + failuresByExec = stageToMockTaskSetBlacklist(1).execToFailures) + // but we shouldn't update for the failed taskset + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet( + stageId = meq(2), + stageAttemptId = anyInt(), + failuresByExec = anyObject()) + } + + test("scheduled tasks obey node and executor blacklists") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + (0 to 2).foreach { stageId => + val taskSet = FakeTask.createTaskSet(numTasks = 2, stageId = stageId, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + } + + val offers = IndexedSeq( + new WorkerOffer("executor0", "host0", 1), + new WorkerOffer("executor1", "host1", 1), + new WorkerOffer("executor2", "host1", 1), + new WorkerOffer("executor3", "host2", 10), + new WorkerOffer("executor4", "host3", 1) + ) + + // setup our mock blacklist: + // host1, executor0 & executor3 are completely blacklisted + // This covers everything *except* one core on executor4 / host3, so that everything is still + // schedulable. + when(blacklist.isNodeBlacklisted("host1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor0")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor3")).thenReturn(true) + + val stageToTsm = (0 to 2).map { stageId => + val tsm = taskScheduler.taskSetManagerForAttempt(stageId, 0).get + stageId -> tsm + }.toMap + + val firstTaskAttempts = taskScheduler.resourceOffers(offers).flatten + firstTaskAttempts.foreach { task => logInfo(s"scheduled $task on ${task.executorId}") } + assert(firstTaskAttempts.size === 1) + assert(firstTaskAttempts.head.executorId === "executor4") + ('0' until '2').foreach { hostNum => + verify(blacklist, atLeast(1)).isNodeBlacklisted("host" + hostNum) + } + } + + test("abort stage when all executors are blacklisted") { + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val taskSet = FakeTask.createTaskSet(numTasks = 10, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + val tsm = stageToMockTaskSetManager(0) + + // first just submit some offers so the scheduler knows about all the executors + taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor1", "host0", 2), + WorkerOffer("executor2", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )) + + // now say our blacklist updates to blacklist a bunch of resources, but *not* everything + when(blacklist.isNodeBlacklisted("host1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor0")).thenReturn(true) + + // make an offer on the blacklisted resources. 
We won't schedule anything, but also won't + // abort yet, since we know of other resources that work + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )).flatten.size === 0) + assert(!tsm.isZombie) + + // now update the blacklist so that everything really is blacklisted + when(blacklist.isExecutorBlacklisted("executor1")).thenReturn(true) + when(blacklist.isExecutorBlacklisted("executor2")).thenReturn(true) + assert(taskScheduler.resourceOffers(IndexedSeq( + WorkerOffer("executor0", "host0", 2), + WorkerOffer("executor3", "host1", 2) + )).flatten.size === 0) + assert(tsm.isZombie) + verify(tsm).abort(anyString(), anyObject()) + } + + /** + * Helper for performance tests. Takes the explicitly blacklisted nodes and executors; verifies + * that the blacklists are used efficiently to ensure scheduling is not O(numPendingTasks). + * Creates 1 offer on executor[1-3]. Executor1 & 2 are on host1, executor3 is on host2. Passed + * in nodes and executors should be on that list. + */ + private def testBlacklistPerformance( + testName: String, + nodeBlacklist: Seq[String], + execBlacklist: Seq[String]): Unit = { + // Because scheduling involves shuffling the order of offers around, we run this test a few + // times to cover more possibilities. There are only 3 offers, which means 6 permutations, + // so 10 iterations is pretty good. + (0 until 10).foreach { testItr => + test(s"$testName: iteration $testItr") { + // When an executor or node is blacklisted, we want to make sure that we don't try + // scheduling each pending task, one by one, to discover they are all blacklisted. This is + // important for performance -- if we did check each task one-by-one, then responding to a + // resource offer (which is usually O(1)-ish) would become O(numPendingTasks), which would + // slow down scheduler throughput and slow down scheduling even on healthy executors. + // Here, we check a proxy for the runtime -- we make sure the scheduling is short-circuited + // at the node or executor blacklist, so we never check the per-task blacklist. We also + // make sure we don't check the node & executor blacklist for the entire taskset + // O(numPendingTasks) times. + + taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + // we schedule 500 tasks so we can clearly distinguish anything that is O(numPendingTasks) + val taskSet = FakeTask.createTaskSet(numTasks = 500, stageId = 0, stageAttemptId = 0) + taskScheduler.submitTasks(taskSet) + + val offers = IndexedSeq( + new WorkerOffer("executor1", "host1", 1), + new WorkerOffer("executor2", "host1", 1), + new WorkerOffer("executor3", "host2", 1) + ) + // We should check the node & exec blacklists, but only O(numOffers), not O(numPendingTasks) + // times. In the worst case, after shuffling, we offer our blacklisted resource first, and + // then offer other resources which do get used. 
The taskset blacklist is consulted
+ // repeatedly as we offer resources to the taskset -- each iteration either schedules
+ // something, or it terminates that locality level, so the maximum number of checks is
+ // numCores + numLocalityLevels
+ val numCoresOnAllOffers = offers.map(_.cores).sum
+ val numLocalityLevels = TaskLocality.values.size
+ val maxBlacklistChecks = numCoresOnAllOffers + numLocalityLevels
+
+ // Setup the blacklist
+ nodeBlacklist.foreach { node =>
+ when(stageToMockTaskSetBlacklist(0).isNodeBlacklistedForTaskSet(node)).thenReturn(true)
+ }
+ execBlacklist.foreach { exec =>
+ when(stageToMockTaskSetBlacklist(0).isExecutorBlacklistedForTaskSet(exec))
+ .thenReturn(true)
+ }
+
+ // Figure out which nodes have any effective blacklisting on them. This means all nodes
+ // that are explicitly blacklisted, plus those that have *any* executors blacklisted.
+ val nodesForBlacklistedExecutors = offers.filter { offer =>
+ execBlacklist.contains(offer.executorId)
+ }.map(_.host).toSet.toSeq
+ val nodesWithAnyBlacklisting = (nodeBlacklist ++ nodesForBlacklistedExecutors).toSet
+ // Similarly, figure out which executors have any blacklisting. This means all executors
+ // that are explicitly blacklisted, plus all executors on nodes that are blacklisted.
+ val execsForBlacklistedNodes = offers.filter { offer =>
+ nodeBlacklist.contains(offer.host)
+ }.map(_.executorId).toSeq
+ val executorsWithAnyBlacklisting = (execBlacklist ++ execsForBlacklistedNodes).toSet
+
+ // Schedule a taskset, and make sure our test setup is correct -- we are able to schedule
+ // a task on all executors that aren't blacklisted (whether that executor is explicitly
+ // blacklisted, or implicitly blacklisted via the node blacklist).
+ val firstTaskAttempts = taskScheduler.resourceOffers(offers).flatten
+ assert(firstTaskAttempts.size === offers.size - executorsWithAnyBlacklisting.size)
+
+ // Now check that we haven't made too many calls to any of the blacklist methods.
+ // We should be checking our node blacklist, but it should be within the bound we defined
+ // above.
+ verify(stageToMockTaskSetBlacklist(0), atMost(maxBlacklistChecks))
+ .isNodeBlacklistedForTaskSet(anyString())
+ // We shouldn't ever consult the per-task blacklist for the nodes that have been blacklisted
+ // for the entire taskset, since the taskset level blacklisting should prevent scheduling
+ // from ever looking at specific tasks.
+ nodesWithAnyBlacklisting.foreach { node =>
+ verify(stageToMockTaskSetBlacklist(0), never)
+ .isNodeBlacklistedForTask(meq(node), anyInt())
+ }
+ executorsWithAnyBlacklisting.foreach { exec =>
+ // We should be checking our executor blacklist, but it should be within the bound defined
+ // above. It's possible that this will be significantly fewer calls, maybe even 0, if
+ // there is also a node-blacklist which takes effect first. But this assert is all we
+ // need to avoid an O(numPendingTasks) slowdown.
+ verify(stageToMockTaskSetBlacklist(0), atMost(maxBlacklistChecks))
+ .isExecutorBlacklistedForTaskSet(exec)
+ // We shouldn't ever consult the per-task blacklist for executors that have been
+ // blacklisted for the entire taskset, since the taskset level blacklisting should prevent
+ // scheduling from ever looking at specific tasks.
+ verify(stageToMockTaskSetBlacklist(0), never) + .isExecutorBlacklistedForTask(meq(exec), anyInt()) + } + } + } + } + + testBlacklistPerformance( + testName = "Blacklisted node for entire task set prevents per-task blacklist checks", + nodeBlacklist = Seq("host1"), + execBlacklist = Seq()) + + testBlacklistPerformance( + testName = "Blacklisted executor for entire task set prevents per-task blacklist checks", + nodeBlacklist = Seq(), + execBlacklist = Seq("executor3") + ) + test("abort stage if executor loss results in unschedulability from previously failed tasks") { // Make sure we can detect when a taskset becomes unschedulable from a blacklisting. This // test explores a particular corner case -- you may have one task fail, but still be @@ -301,27 +632,27 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B )).flatten assert(Set("executor0", "executor1") === firstTaskAttempts.map(_.executorId).toSet) - // fail one of the tasks, but leave the other running + // Fail one of the tasks, but leave the other running. val failedTask = firstTaskAttempts.find(_.executorId == "executor0").get taskScheduler.handleFailedTask(tsm, failedTask.taskId, TaskState.FAILED, TaskResultLost) - // at this point, our failed task could run on the other executor, so don't give up the task + // At this point, our failed task could run on the other executor, so don't give up the task // set yet. assert(!failedTaskSet) // Now we fail our second executor. The other task can still run on executor1, so make an offer - // on that executor, and make sure that the other task (not the failed one) is assigned there + // on that executor, and make sure that the other task (not the failed one) is assigned there. taskScheduler.executorLost("executor1", SlaveLost("oops")) val nextTaskAttempts = taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))).flatten // Note: Its OK if some future change makes this already realize the taskset has become - // unschedulable at this point (though in the current implementation, we're sure it will not) + // unschedulable at this point (though in the current implementation, we're sure it will not). assert(nextTaskAttempts.size === 1) assert(nextTaskAttempts.head.executorId === "executor0") assert(nextTaskAttempts.head.attemptNumber === 1) assert(nextTaskAttempts.head.index != failedTask.index) - // now we should definitely realize that our task set is unschedulable, because the only - // task left can't be scheduled on any executors due to the blacklist + // Now we should definitely realize that our task set is unschedulable, because the only + // task left can't be scheduled on any executors due to the blacklist. 
taskScheduler.resourceOffers(IndexedSeq(new WorkerOffer("executor0", "host0", 1))) sc.listenerBus.waitUntilEmpty(100000) assert(tsm.isZombie) @@ -408,4 +739,175 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(thirdTaskDescs.size === 0) assert(taskScheduler.getExecutorsAliveOnHost("host1") === Some(Set("executor1", "executor3"))) } + + test("scheduler checks for executors that can be expired from blacklist") { + taskScheduler = setupScheduler() + + taskScheduler.submitTasks(FakeTask.createTaskSet(1, 0)) + taskScheduler.resourceOffers(IndexedSeq( + new WorkerOffer("executor0", "host0", 1) + )).flatten + + verify(blacklist).applyBlacklistTimeout() + } + + test("if an executor is lost then the state for its running tasks is cleaned up (SPARK-18553)") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // mark executor0 as dead + taskScheduler.executorLost("executor0", SlaveLost()) + assert(!taskScheduler.isExecutorAlive("executor0")) + assert(!taskScheduler.hasExecutorsAliveOnHost("host0")) + assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty) + + + // Check that state associated with the lost task attempt is cleaned up: + assert(taskScheduler.taskIdToExecutorId.isEmpty) + assert(taskScheduler.taskIdToTaskSetManager.isEmpty) + assert(taskScheduler.runningTasksByExecutors.get("executor0").isEmpty) + } + + test("if a task finishes with TaskState.LOST its executor is marked as dead") { + sc = new SparkContext("local", "TaskSchedulerImplSuite") + val taskScheduler = new TaskSchedulerImpl(sc) + taskScheduler.initialize(new FakeSchedulerBackend) + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. 
+ new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + + val e0Offers = IndexedSeq(WorkerOffer("executor0", "host0", 1)) + val attempt1 = FakeTask.createTaskSet(1) + + // submit attempt 1, offer resources, task gets scheduled + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(e0Offers).flatten + assert(1 === taskDescriptions.length) + + // Report the task as failed with TaskState.LOST + taskScheduler.statusUpdate( + tid = taskDescriptions.head.taskId, + state = TaskState.LOST, + serializedData = ByteBuffer.allocate(0) + ) + + // Check that state associated with the lost task attempt is cleaned up: + assert(taskScheduler.taskIdToExecutorId.isEmpty) + assert(taskScheduler.taskIdToTaskSetManager.isEmpty) + assert(taskScheduler.runningTasksByExecutors.get("executor0").isEmpty) + + // Check that the executor has been marked as dead + assert(!taskScheduler.isExecutorAlive("executor0")) + assert(!taskScheduler.hasExecutorsAliveOnHost("host0")) + assert(taskScheduler.getExecutorsAliveOnHost("host0").isEmpty) + } + + test("Locality should be used for bulk offers even with delay scheduling off") { + val conf = new SparkConf() + .set("spark.locality.wait", "0") + sc = new SparkContext("local", "TaskSchedulerImplSuite", conf) + // we create a manual clock just so we can be sure the clock doesn't advance at all in this test + val clock = new ManualClock() + + // We customize the task scheduler just to let us control the way offers are shuffled, so we + // can be sure we try both permutations, and to control the clock on the tasksetmanager. + val taskScheduler = new TaskSchedulerImpl(sc) { + override def shuffleOffers(offers: IndexedSeq[WorkerOffer]): IndexedSeq[WorkerOffer] = { + // Don't shuffle the offers around for this test. Instead, we'll just pass in all + // the permutations we care about directly. + offers + } + override def createTaskSetManager(taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock) + } + } + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + taskScheduler.initialize(new FakeSchedulerBackend) + + // Make two different offers -- one in the preferred location, one that is not. + val offers = IndexedSeq( + WorkerOffer("exec1", "host1", 1), + WorkerOffer("exec2", "host2", 1) + ) + Seq(false, true).foreach { swapOrder => + // Submit a taskset with locality preferences. + val taskSet = FakeTask.createTaskSet( + 1, stageId = 1, stageAttemptId = 0, Seq(TaskLocation("host1", "exec1"))) + taskScheduler.submitTasks(taskSet) + val shuffledOffers = if (swapOrder) offers.reverse else offers + // Regardless of the order of the offers (after the task scheduler shuffles them), we should + // always take advantage of the local offer. 
+ val taskDescs = taskScheduler.resourceOffers(shuffledOffers).flatten + withClue(s"swapOrder = $swapOrder") { + assert(taskDescs.size === 1) + assert(taskDescs.head.executorId === "exec1") + } + } + } + + test("With delay scheduling off, tasks can be run at any locality level immediately") { + val conf = new SparkConf() + .set("spark.locality.wait", "0") + sc = new SparkContext("local", "TaskSchedulerImplSuite", conf) + + // we create a manual clock just so we can be sure the clock doesn't advance at all in this test + val clock = new ManualClock() + val taskScheduler = new TaskSchedulerImpl(sc) { + override def createTaskSetManager(taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock) + } + } + // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. + new DAGScheduler(sc, taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) {} + override def executorAdded(execId: String, host: String) {} + } + taskScheduler.initialize(new FakeSchedulerBackend) + // make an offer on the preferred host so the scheduler knows its alive. This is necessary + // so that the taskset knows that it *could* take advantage of locality. + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec1", "host1", 1))) + + // Submit a taskset with locality preferences. + val taskSet = FakeTask.createTaskSet( + 1, stageId = 1, stageAttemptId = 0, Seq(TaskLocation("host1", "exec1"))) + taskScheduler.submitTasks(taskSet) + val tsm = taskScheduler.taskSetManagerForAttempt(1, 0).get + // make sure we've setup our test correctly, so that the taskset knows it *could* use local + // offers. + assert(tsm.myLocalityLevels.contains(TaskLocality.NODE_LOCAL)) + // make an offer on a non-preferred location. Since the delay is 0, we should still schedule + // immediately. 
+ val taskDescs = + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec2", "host2", 1))).flatten + assert(taskDescs.size === 1) + assert(taskDescs.head.executorId === "exec2") + } + + test("TaskScheduler should throw IllegalArgumentException when schedulingMode is not supported") { + intercept[IllegalArgumentException] { + val taskScheduler = setupScheduler( + TaskSchedulerImpl.SCHEDULER_MODE_PROPERTY -> SchedulingMode.NONE.toString) + taskScheduler.initialize(new FakeSchedulerBackend) + } + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index 8c902af5685ff..6b52c10b2c68b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -85,9 +85,9 @@ class TaskSetBlacklistSuite extends SparkFunSuite { Seq("exec1", "exec2").foreach { exec => assert( - execToFailures(exec).taskToFailureCount === Map( - 0 -> 1, - 1 -> 1 + execToFailures(exec).taskToFailureCountAndFailureTime === Map( + 0 -> (1, 0), + 1 -> (1, 0) ) ) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1b1a764ceff95..db14c9acfdce5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,16 +17,21 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Mockito.{mock, verify} +import org.mockito.Matchers.{any, anyInt, anyString} +import org.mockito.Mockito.{mock, never, spy, verify, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerInstance +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ManualClock} class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler) @@ -181,7 +186,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdates = taskSet.tasks.head.metrics.internalAccums // Offer a host with NO_PREF as the constraint, @@ -189,6 +194,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) + clock.advance(1) // Tell it the task has finished manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates)) assert(sched.endedTasks(0) === Success) @@ -234,7 +240,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2")) val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, 
taskSet, MAX_TASK_FAILURES, clock = clock) // An executor that is not NODE_LOCAL should be rejected. assert(manager.resourceOffer("execC", "host2", ANY) === None) @@ -255,7 +261,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq() // Last task has no locality prefs ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) == None) @@ -284,7 +290,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq() // Last task has no locality prefs ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1, exec1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0) assert(manager.resourceOffer("exec3", "host2", PROCESS_LOCAL).get.index === 1) @@ -304,7 +310,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2")) ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -342,7 +348,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host3")) ) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // First offer host1: first task should be chosen assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -374,7 +380,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + clock.advance(1) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) @@ -391,7 +398,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + clock.advance(1) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. @@ -424,7 +432,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // affinity to exec1 on host1 - which we will fail. 
val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "exec1"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, 4, clock) + clock.advance(1) + // We don't directly use the application blacklist, but its presence triggers blacklisting + // within the taskset. + val mockListenerBus = mock(classOf[LiveListenerBus]) + val blacklistTrackerOpt = Some(new BlacklistTracker(mockListenerBus, conf, None, clock)) + val manager = new TaskSetManager(sched, taskSet, 4, blacklistTrackerOpt, clock) { val offerResult = manager.resourceOffer("exec1", "host1", PROCESS_LOCAL) @@ -513,7 +526,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2", "execC")), Seq()) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY))) // Add a new executor @@ -544,7 +557,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execB")), Seq(TaskLocation("host2", "execC")), Seq()) - val manager = new TaskSetManager(sched, taskSet, 1, new ManualClock) + val clock = new ManualClock() + clock.advance(1) + val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) sched.addExecutor("execA", "host1") manager.executorAdded() sched.addExecutor("execC", "host2") @@ -577,7 +592,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host1", "execA"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY))) // Set allowed locality to ANY @@ -658,6 +673,71 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } + test("[SPARK-13931] taskSetManager should not send Resubmitted tasks after being a zombie") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = {} + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. 
+ var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val singleTask = new ShuffleMapTask(0, 0, null, new Partition { + override def index: Int = 0 + }, Seq(TaskLocation("host1", "execA")), new Properties, null) + val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + // Offer host1, which should be accepted as a PROCESS_LOCAL location + // by the one task in the task set + val task1 = manager.resourceOffer("execA", "host1", TaskLocality.PROCESS_LOCAL).get + + // Mark the task as available for speculation, and then offer another resource, + // which should be used to launch a speculative copy of the task. + manager.speculatableTasks += singleTask.partitionId + val task2 = manager.resourceOffer("execB", "host2", TaskLocality.ANY).get + + assert(manager.runningTasks === 2) + assert(manager.isZombie === false) + + val directTaskResult = new DirectTaskResult[String](null, Seq()) { + override def value(resultSer: SerializerInstance): String = "" + } + // Complete one copy of the task, which should result in the task set manager + // being marked as a zombie, because at least one copy of its only task has completed. + manager.handleSuccessfulTask(task1.taskId, directTaskResult) + assert(manager.isZombie === true) + assert(resubmittedTasks === 0) + assert(manager.runningTasks === 1) + + manager.executorLost("execB", "host2", new SlaveLost()) + assert(manager.runningTasks === 0) + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( @@ -668,7 +748,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(), Seq(TaskLocation("host3", "execC"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0) assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None) @@ -696,7 +776,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(), Seq(TaskLocation("host3"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // node-local tasks are scheduled without delay assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 0) @@ -718,7 +798,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(ExecutorCacheTaskLocation("host1", "execA")), Seq(ExecutorCacheTaskLocation("host2", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // process-local tasks are scheduled first assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 2) @@ -738,7 +818,7 @@ class TaskSetManagerSuite extends SparkFunSuite with 
LocalSparkContext with Logg Seq(ExecutorCacheTaskLocation("host1", "execA")), Seq(ExecutorCacheTaskLocation("host2", "execB"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // process-local tasks are scheduled first assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 1) @@ -758,7 +838,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host1", "execA")), Seq(TaskLocation("host2", "execB.1"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) // Only ANY is valid assert(manager.myLocalityLevels.sameElements(Array(ANY))) // Add a new executor @@ -792,7 +872,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg Seq(TaskLocation("host2")), Seq(TaskLocation("hdfs_cache_host3"))) val clock = new ManualClock - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY))) sched.removeExecutor("execA") manager.executorAdded() @@ -819,8 +899,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => task.metrics.internalAccums } @@ -836,6 +917,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(task.executorId === k) } assert(sched.startedTasks.toSet === Set(0, 1, 2, 3)) + clock.advance(1) // Complete the 3 tasks and leave 1 task in running for (id <- Set(0, 1, 2)) { manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) @@ -859,7 +941,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Complete the speculative attempt for the running task manager.handleSuccessfulTask(4, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt - verify(sched.backend).killTask(3, "exec2", true) + verify(sched.backend).killTask(3, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. 
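The speculation tests above hinge on the updated kill call: when a speculative copy finishes first, the TaskSetManager asks the backend to kill the other running attempt through the four-argument `killTask(taskId, executorId, interruptThread, reason)`, passing the reason "another attempt succeeded". The sketch below is only an illustrative stand-in and not part of this patch; the `RecordingBackend` stub and the literal values it receives are hypothetical, and only the four-argument `killTask` shape is taken from the diff.

```scala
// Illustrative sketch only -- not part of this patch. It mimics the four-argument
// killTask(taskId, executorId, interruptThread, reason) shape that these tests verify,
// using a hypothetical recording stub instead of a Mockito mock.
object KillReasonSketch {
  // Hypothetical stand-in for the backend surface the speculation tests exercise.
  trait KillSupport {
    def killTask(taskId: Long, executorId: String, interruptThread: Boolean, reason: String): Unit
  }

  // Records every kill request so a test could assert on the reason afterwards.
  class RecordingBackend extends KillSupport {
    val kills = scala.collection.mutable.ArrayBuffer.empty[(Long, String, Boolean, String)]
    override def killTask(taskId: Long, executorId: String, interruptThread: Boolean,
        reason: String): Unit = {
      kills += ((taskId, executorId, interruptThread, reason))
    }
  }

  def main(args: Array[String]): Unit = {
    val backend = new RecordingBackend
    // A speculative copy finishing first would trigger a call like the one the suite verifies.
    backend.killTask(3L, "exec2", interruptThread = true, reason = "another attempt succeeded")
    assert(backend.kills.head._4 == "another attempt succeeded")
    println(backend.kills)
  }
}
```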
@@ -873,8 +955,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Set the speculation multiplier to be 0 so speculative tasks are launched immediately sc.conf.set("spark.speculation.multiplier", "0.0") sc.conf.set("spark.speculation.quantile", "0.6") + sc.conf.set("spark.speculation", "true") val clock = new ManualClock() - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock) val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task => task.metrics.internalAccums } @@ -893,6 +976,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg tasks += task } assert(sched.startedTasks.toSet === (0 until 5).toSet) + clock.advance(1) // Complete 3 tasks and leave 2 tasks in running for (id <- Set(0, 1, 2)) { manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id))) @@ -945,14 +1029,14 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg manager.handleSuccessfulTask(speculativeTask.taskId, createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask = originalTasks(speculativeTask.index) - verify(sched.backend).killTask(origTask.taskId, "exec2", true) + verify(sched.backend).killTask(origTask.taskId, "exec2", true, "another attempt succeeded") // Because the SchedulerBackend was a mock, the 2nd copy of the task won't actually be // killed, so the FakeTaskScheduler is only told about the successful completion // of the speculated task. assert(sched.endedTasks(3) === Success) // also because the scheduler is a mock, our manager isn't notified about the task killed event, // so we do that manually - manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled) + manager.handleFailedTask(origTask.taskId, TaskState.KILLED, TaskKilled("test")) // this task has "failed" 4 times, but one of them doesn't count, so keep running the stage assert(manager.tasksSuccessful === 4) assert(!manager.isZombie) @@ -969,29 +1053,93 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg createTaskResult(3, accumUpdatesByTask(3))) // Verify that it kills other running attempt val origTask2 = originalTasks(speculativeTask2.index) - verify(sched.backend).killTask(origTask2.taskId, "exec2", true) + verify(sched.backend).killTask(origTask2.taskId, "exec2", true, "another attempt succeeded") assert(manager.tasksSuccessful === 5) assert(manager.isZombie) } + + test("SPARK-19868: DagScheduler only notified of taskEnd when state is ready") { + // dagScheduler.taskEnded() is async, so it may *seem* ok to call it before we've set all + // appropriate state, eg. isZombie. However, this sets up a race that could go the wrong way. + // This is a super-focused regression test which checks the zombie state as soon as + // dagScheduler.taskEnded() is called, to ensure we haven't introduced a race. 
+ sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val mockDAGScheduler = mock(classOf[DAGScheduler]) + sched.dagScheduler = mockDAGScheduler + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) + val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) + assert(taskOption.isDefined) + // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon + manager.handleSuccessfulTask(0, createTaskResult(0)) + } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock) assert(manager.name === "TaskSet_0.0") // Make sure a task set with the same stage ID but different attempt ID has a unique name val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 1) - val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, clock = new ManualClock) assert(manager2.name === "TaskSet_0.1") // Make sure a task set with the same attempt ID but different stage ID also has a unique name val taskSet3 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 1) - val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock) + val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, clock = new ManualClock) assert(manager3.name === "TaskSet_1.1") } + test("don't update blacklist for shuffle-fetch failures, preemption, denied commits, " + + "or killed tasks") { + // Setup a taskset, and fail some tasks for a fetch failure, preemption, denied commit, + // and killed task. + val conf = new SparkConf(). 
+ set(config.BLACKLIST_ENABLED, true) + sc = new SparkContext("local", "test", conf) + sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(4) + val tsm = new TaskSetManager(sched, taskSet, 4) + // we need a spy so we can attach our mock blacklist + val tsmSpy = spy(tsm) + val blacklist = mock(classOf[TaskSetBlacklist]) + when(tsmSpy.taskSetBlacklistHelperOpt).thenReturn(Some(blacklist)) + + // make some offers to our taskset, to get tasks we will fail + val taskDescs = Seq( + "exec1" -> "host1", + "exec2" -> "host1" + ).flatMap { case (exec, host) => + // offer each executor twice (simulating 2 cores per executor) + (0 until 2).flatMap{ _ => tsmSpy.resourceOffer(exec, host, TaskLocality.ANY)} + } + assert(taskDescs.size === 4) + + // now fail those tasks + tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED, + ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) + tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, + TaskCommitDenied(0, 2, 0)) + tsmSpy.handleFailedTask(taskDescs(3).taskId, TaskState.KILLED, TaskKilled("test")) + + // Make sure that the blacklist ignored all of the task failures above, since they aren't + // the fault of the executor where the task was running. + verify(blacklist, never()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala index 81eb907ac7ba6..608052f5ed855 100644 --- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala @@ -16,25 +16,31 @@ */ package org.apache.spark.security -import java.security.PrivilegedExceptionAction +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} +import java.nio.channels.Channels +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files +import java.util.{Arrays, Random, UUID} -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import com.google.common.io.ByteStreams -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.internal.config._ +import org.apache.spark.network.util.CryptoUtils import org.apache.spark.security.CryptoStreamUtils._ +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} +import org.apache.spark.storage.TempShuffleBlockId class CryptoStreamUtilsSuite extends SparkFunSuite { - val ugi = UserGroupInformation.createUserForTesting("testuser", Array("testgroup")) - test("Crypto configuration conversion") { + test("crypto configuration conversion") { val sparkKey1 = s"${SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX}a.b.c" val sparkVal1 = "val1" - val cryptoKey1 = s"${COMMONS_CRYPTO_CONF_PREFIX}a.b.c" + val cryptoKey1 = s"${CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX}a.b.c" val sparkKey2 = SPARK_IO_ENCRYPTION_COMMONS_CONFIG_PREFIX.stripSuffix(".") + "A.b.c" val sparkVal2 = "val2" - val cryptoKey2 = s"${COMMONS_CRYPTO_CONF_PREFIX}A.b.c" + val cryptoKey2 = 
s"${CryptoUtils.COMMONS_CRYPTO_CONFIG_PREFIX}A.b.c" val conf = new SparkConf() conf.set(sparkKey1, sparkVal1) conf.set(sparkKey2, sparkVal2) @@ -43,65 +49,125 @@ class CryptoStreamUtilsSuite extends SparkFunSuite { assert(!props.containsKey(cryptoKey2)) } - test("Shuffle encryption is disabled by default") { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - val credentials = UserGroupInformation.getCurrentUser.getCredentials() - val conf = new SparkConf() - initCredentials(conf, credentials) - assert(credentials.getSecretKey(SPARK_IO_TOKEN) === null) - } - }) + test("shuffle encryption key length should be 128 by default") { + val conf = createConf() + var key = CryptoStreamUtils.createKey(conf) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === 128) } - test("Shuffle encryption key length should be 128 by default") { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - val credentials = UserGroupInformation.getCurrentUser.getCredentials() - val conf = new SparkConf() - conf.set(IO_ENCRYPTION_ENABLED, true) - initCredentials(conf, credentials) - var key = credentials.getSecretKey(SPARK_IO_TOKEN) - assert(key !== null) - val actual = key.length * (java.lang.Byte.SIZE) - assert(actual === 128) - } - }) + test("create 256-bit key") { + val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "256") + var key = CryptoStreamUtils.createKey(conf) + val actual = key.length * (java.lang.Byte.SIZE) + assert(actual === 256) } - test("Initial credentials with key length in 256") { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - val credentials = UserGroupInformation.getCurrentUser.getCredentials() - val conf = new SparkConf() - conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 256) - conf.set(IO_ENCRYPTION_ENABLED, true) - initCredentials(conf, credentials) - var key = credentials.getSecretKey(SPARK_IO_TOKEN) - assert(key !== null) - val actual = key.length * (java.lang.Byte.SIZE) - assert(actual === 256) - } - }) + test("create key with invalid length") { + intercept[IllegalArgumentException] { + val conf = createConf(IO_ENCRYPTION_KEY_SIZE_BITS.key -> "328") + CryptoStreamUtils.createKey(conf) + } } - test("Initial credentials with invalid key length") { - ugi.doAs(new PrivilegedExceptionAction[Unit]() { - override def run(): Unit = { - val credentials = UserGroupInformation.getCurrentUser.getCredentials() - val conf = new SparkConf() - conf.set(IO_ENCRYPTION_KEY_SIZE_BITS, 328) - conf.set(IO_ENCRYPTION_ENABLED, true) - val thrown = intercept[IllegalArgumentException] { - initCredentials(conf, credentials) - } - } - }) + test("serializer manager integration") { + val conf = createConf() + .set("spark.shuffle.compress", "true") + .set("spark.shuffle.spill.compress", "true") + + val plainStr = "hello world" + val blockId = new TempShuffleBlockId(UUID.randomUUID()) + val key = Some(CryptoStreamUtils.createKey(conf)) + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf, + encryptionKey = key) + + val outputStream = new ByteArrayOutputStream() + val wrappedOutputStream = serializerManager.wrapStream(blockId, outputStream) + wrappedOutputStream.write(plainStr.getBytes(UTF_8)) + wrappedOutputStream.close() + + val encryptedBytes = outputStream.toByteArray + val encryptedStr = new String(encryptedBytes, UTF_8) + assert(plainStr !== encryptedStr) + + val inputStream = new ByteArrayInputStream(encryptedBytes) + val wrappedInputStream = serializerManager.wrapStream(blockId, 
inputStream) + val decryptedBytes = ByteStreams.toByteArray(wrappedInputStream) + val decryptedStr = new String(decryptedBytes, UTF_8) + assert(decryptedStr === plainStr) } - private[this] def initCredentials(conf: SparkConf, credentials: Credentials): Unit = { - if (conf.get(IO_ENCRYPTION_ENABLED)) { - SecurityManager.initIOEncryptionKey(conf, credentials) + test("encryption key propagation to executors") { + val conf = createConf().setAppName("Crypto Test").setMaster("local-cluster[1,1,1024]") + val sc = new SparkContext(conf) + try { + val content = "This is the content to be encrypted." + val encrypted = sc.parallelize(Seq(1)) + .map { str => + val bytes = new ByteArrayOutputStream() + val out = CryptoStreamUtils.createCryptoOutputStream(bytes, SparkEnv.get.conf, + SparkEnv.get.securityManager.getIOEncryptionKey().get) + out.write(content.getBytes(UTF_8)) + out.close() + bytes.toByteArray() + }.collect()(0) + + assert(content != encrypted) + + val in = CryptoStreamUtils.createCryptoInputStream(new ByteArrayInputStream(encrypted), + sc.conf, SparkEnv.get.securityManager.getIOEncryptionKey().get) + val decrypted = new String(ByteStreams.toByteArray(in), UTF_8) + assert(content === decrypted) + } finally { + sc.stop() + } + } + + test("crypto stream wrappers") { + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = createConf() + val key = createKey(conf) + val file = Files.createTempFile("crypto", ".test").toFile() + + val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key) + try { + ByteStreams.copy(new ByteArrayInputStream(testData), outStream) + } finally { + outStream.close() + } + + val inStream = createCryptoInputStream(new FileInputStream(file), conf, key) + try { + val inStreamData = ByteStreams.toByteArray(inStream) + assert(Arrays.equals(inStreamData, testData)) + } finally { + inStream.close() + } + + val outChannel = createWritableChannel(new FileOutputStream(file).getChannel(), conf, key) + try { + val inByteChannel = Channels.newChannel(new ByteArrayInputStream(testData)) + ByteStreams.copy(inByteChannel, outChannel) + } finally { + outChannel.close() + } + + val inChannel = createReadableChannel(new FileInputStream(file).getChannel(), conf, key) + try { + val inChannelData = ByteStreams.toByteArray(Channels.newInputStream(inChannel)) + assert(Arrays.equals(inChannelData, testData)) + } finally { + inChannel.close() } } + + private def createConf(extra: (String, String)*): SparkConf = { + val conf = new SparkConf() + extra.foreach { case (k, v) => conf.set(k, v) } + conf.set(IO_ENCRYPTION_ENABLED, true) + conf + } + } diff --git a/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala new file mode 100644 index 0000000000000..3f52dc41abf6d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/EncryptionFunSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ + +trait EncryptionFunSuite { + + this: SparkFunSuite => + + /** + * Runs a test twice, initializing a SparkConf object with encryption off, then on. It's ok + * for the test to modify the provided SparkConf. + */ + final protected def encryptionTest(name: String)(fn: SparkConf => Unit) { + Seq(false, true).foreach { encrypt => + test(s"$name (encryption = ${ if (encrypt) "on" else "off" })") { + val conf = new SparkConf().set(IO_ENCRYPTION_ENABLED, encrypt) + fn(conf) + } + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 5040841811054..7c3922e47fbb9 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag -import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import org.roaringbitmap.RoaringBitmap @@ -76,6 +76,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("basic types") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -106,6 +109,9 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("pairs") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) @@ -130,12 +136,16 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { } test("Scala data structures") { + val conf = new SparkConf(false) + conf.set("spark.kryo.registrationRequired", "true") + val ser = new KryoSerializer(conf).newInstance() def check[T: ClassTag](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } check(List[Int]()) check(List[Int](1, 2, 3)) + check(Seq[Int](1, 2, 3)) check(List[String]()) check(List[String]("x", "y", "z")) check(None) @@ -351,6 +361,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val ser = new KryoSerializer(conf).newInstance() val thrown = intercept[SparkException](ser.serialize(largeObject)) assert(thrown.getMessage.contains(kryoBufferMaxProperty)) + assert(thrown.getCause.isInstanceOf[KryoException]) } test("SPARK-12222: deserialize RoaringBitmap throw Buffer underflow exception") { diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala index 4ce3b941bea55..99882bf76e29d 100644 --- 
a/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializerPropertiesSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset /** * Tests to ensure that [[Serializer]] implementations obey the API contracts for methods that * describe properties of the serialized stream, such as - * [[Serializer.supportsRelocationOfSerializedObjects]]. + * `Serializer.supportsRelocationOfSerializedObjects`. */ class SerializerPropertiesSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 442941685f1ae..85ccb33471048 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte )).thenAnswer(new Answer[DiskBlockObjectWriter] { override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments + val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( args(1).asInstanceOf[File], + manager, args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], - wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 89ed031b6fcd1..f0c521b00b583 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.UUID + import org.apache.spark.SparkFunSuite class BlockIdSuite extends SparkFunSuite { @@ -67,6 +69,32 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle data") { + val id = ShuffleDataBlockId(4, 5, 6) + assertSame(id, ShuffleDataBlockId(4, 5, 6)) + assertDifferent(id, ShuffleDataBlockId(6, 5, 6)) + assert(id.name === "shuffle_4_5_6.data") + assert(id.asRDDId === None) + assert(id.shuffleId === 4) + assert(id.mapId === 5) + assert(id.reduceId === 6) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + + test("shuffle index") { + val id = ShuffleIndexBlockId(7, 8, 9) + assertSame(id, ShuffleIndexBlockId(7, 8, 9)) + assertDifferent(id, ShuffleIndexBlockId(9, 8, 9)) + assert(id.name === "shuffle_7_8_9.index") + assert(id.asRDDId === None) + assert(id.shuffleId === 7) + assert(id.mapId === 8) + assert(id.reduceId === 9) + assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("broadcast") { val id = BroadcastBlockId(42) assertSame(id, BroadcastBlockId(42)) @@ -101,6 +129,30 @@ class BlockIdSuite extends 
SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("temp local") { + val id = TempLocalBlockId(new UUID(5, 2)) + assertSame(id, TempLocalBlockId(new UUID(5, 2))) + assertDifferent(id, TempLocalBlockId(new UUID(5, 3))) + assert(id.name === "temp_local_00000000-0000-0005-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 5) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + + test("temp shuffle") { + val id = TempShuffleBlockId(new UUID(1, 2)) + assertSame(id, TempShuffleBlockId(new UUID(1, 2))) + assertDifferent(id, TempShuffleBlockId(new UUID(1, 3))) + assert(id.name === "temp_shuffle_00000000-0000-0001-0000-000000000002") + assert(id.asRDDId === None) + assert(id.isBroadcast === false) + assert(id.id.getMostSignificantBits() === 1) + assert(id.id.getLeastSignificantBits() === 2) + assert(!id.isShuffle) + } + test("test") { val id = TestBlockId("abc") assertSame(id, TestBlockId("abc")) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index f4bfdc2fd69a9..c100803279eaf 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.Locale + import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions @@ -28,6 +30,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager +import org.apache.spark.internal.Logging import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService @@ -36,33 +39,33 @@ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ +import org.apache.spark.util.Utils + +trait BlockManagerReplicationBehavior extends SparkFunSuite + with Matchers + with BeforeAndAfter + with LocalSparkContext { -/** Testsuite that tests block replication in BlockManager */ -class BlockManagerReplicationSuite extends SparkFunSuite - with Matchers - with BeforeAndAfter - with LocalSparkContext { + val conf: SparkConf - private val conf = new SparkConf(false).set("spark.app.id", "test") - private var rpcEnv: RpcEnv = null - private var master: BlockManagerMaster = null - private val securityMgr = new SecurityManager(conf) - private val bcastManager = new BroadcastManager(true, conf, securityMgr) - private val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) - private val shuffleManager = new SortShuffleManager(conf) + protected var rpcEnv: RpcEnv = null + protected var master: BlockManagerMaster = null + protected lazy val securityMgr = new SecurityManager(conf) + protected lazy val bcastManager = new BroadcastManager(true, conf, securityMgr) + protected lazy val mapOutputTracker = new MapOutputTrackerMaster(conf, bcastManager, true) + protected lazy val shuffleManager = new SortShuffleManager(conf) // List of block managers created during a unit test, so that all of them can be stopped // after the unit test.
- private val allStores = new ArrayBuffer[BlockManager] + protected val allStores = new ArrayBuffer[BlockManager] // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test - conf.set("spark.kryoserializer.buffer", "1m") - private val serializer = new KryoSerializer(conf) + protected lazy val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. - private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) - private def makeBlockManager( + protected def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { conf.set("spark.testing.memory", maxMem.toString) @@ -355,7 +358,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite * is correct. Then it also drops the block from memory of each store (using LRU) and * again checks whether the master's knowledge gets updated. */ - private def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { + protected def testReplication(maxReplication: Int, storageLevels: Seq[StorageLevel]) { import org.apache.spark.storage.StorageLevel._ assert(maxReplication > 1, @@ -373,9 +376,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite storageLevels.foreach { storageLevel => // Put the block into one of the stores - val blockId = new TestBlockId( - "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase) - stores(0).putSingle(blockId, new Array[Byte](blockSize), storageLevel) + val blockId = TestBlockId( + "block-with-" + storageLevel.description.replace(" ", "-").toLowerCase(Locale.ROOT)) + val testValue = Array.fill[Byte](blockSize)(1) + stores(0).putSingle(blockId, testValue, storageLevel) // Assert that master know two locations for the block val blockLocations = master.getLocations(blockId).map(_.executorId).toSet @@ -387,12 +391,23 @@ class BlockManagerReplicationSuite extends SparkFunSuite testStore => blockLocations.contains(testStore.blockManagerId.executorId) }.foreach { testStore => val testStoreName = testStore.blockManagerId.executorId - assert( - testStore.getLocalValues(blockId).isDefined, s"$blockId was not found in $testStoreName") - testStore.releaseLock(blockId) + val blockResultOpt = testStore.getLocalValues(blockId) + assert(blockResultOpt.isDefined, s"$blockId was not found in $testStoreName") + val localValues = blockResultOpt.get.data.toSeq + assert(localValues.size == 1) + assert(localValues.head === testValue) assert(master.getLocations(blockId).map(_.executorId).toSet.contains(testStoreName), s"master does not have status for ${blockId.name} in $testStoreName") + val memoryStore = testStore.memoryStore + if (memoryStore.contains(blockId) && !storageLevel.deserialized) { + memoryStore.getBytes(blockId).get.chunks.foreach { byteBuffer => + assert(storageLevel.useOffHeap == byteBuffer.isDirect, + s"memory mode ${storageLevel.memoryMode} is not compatible with " + + byteBuffer.getClass.getSimpleName) + } + } + val blockStatus = master.getBlockStatus(blockId)(testStore.blockManagerId) // Assert that block status in the master for this store has expected storage level @@ -448,3 +463,95 @@ class BlockManagerReplicationSuite extends SparkFunSuite } } } + +class BlockManagerReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") +} + +class 
BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehavior { + val conf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set("spark.storage.replication.proactive", "true") + conf.set("spark.storage.exceptionOnPinLeak", "true") + + (2 to 5).foreach { i => + test(s"proactive block replication - $i replicas - ${i - 1} block manager deletions") { + testProactiveReplication(i) + } + } + + def testProactiveReplication(replicationFactor: Int) { + val blockSize = 1000 + val storeSize = 10000 + val initialStores = (1 to 10).map { i => makeBlockManager(storeSize, s"store$i") } + + val blockId = "a1" + + val storageLevel = StorageLevel(true, true, false, true, replicationFactor) + initialStores.head.putSingle(blockId, new Array[Byte](blockSize), storageLevel) + + val blockLocations = master.getLocations(blockId) + logInfo(s"Initial locations : $blockLocations") + + assert(blockLocations.size === replicationFactor) + + // remove all but one of the block managers holding the block, one at a time + val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet + logInfo(s"Removing $executorsToRemove") + initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm => + master.removeExecutor(bm.blockManagerId.executorId) + bm.stop() + // give replication enough time to happen and the new block to be reported to the master + eventually(timeout(5 seconds), interval(100 millis)) { + val newLocations = master.getLocations(blockId).toSet + assert(newLocations.size === replicationFactor) + } + } + + val newLocations = eventually(timeout(5 seconds), interval(100 millis)) { + val _newLocations = master.getLocations(blockId).toSet + assert(_newLocations.size === replicationFactor) + _newLocations + } + logInfo(s"New locations : $newLocations") + + // new locations should not contain stopped block managers + assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)), + "New locations contain stopped block managers.") + + // Make sure all locks have been released.
+ eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0) + } + } + } +} + +class DummyTopologyMapper(conf: SparkConf) extends TopologyMapper(conf) with Logging { + // number of racks to test with + val numRacks = 3 + + /** + * Gets the topology information given the host name + * + * @param hostname Hostname + * @return random topology + */ + override def getTopologyForHost(hostname: String): Option[String] = { + Some(s"/Rack-${Utils.random.nextInt(numRacks)}") + } +} + +class BlockManagerBasicStrategyReplicationSuite extends BlockManagerReplicationBehavior { + val conf: SparkConf = new SparkConf(false).set("spark.app.id", "test") + conf.set("spark.kryoserializer.buffer", "1m") + conf.set( + "spark.storage.replication.policy", + classOf[BasicBlockReplicationPolicy].getName) + conf.set( + "spark.storage.replication.topologyMapper", + classOf[DummyTopologyMapper].getName) +} + diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 705c355234425..1e7bcdb6740f6 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -35,6 +35,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -42,6 +43,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -49,7 +51,8 @@ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach - with PrivateMethodTester with LocalSparkContext with ResetSystemProperties { + with PrivateMethodTester with LocalSparkContext with ResetSystemProperties + with EncryptionFunSuite { import BlockManagerSuite._ @@ -75,16 +78,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, - transferService: Option[BlockTransferService] = Option.empty): BlockManager = { - conf.set("spark.testing.memory", maxMem.toString) - conf.set("spark.memory.offHeap.size", maxMem.toString) - val serializer = new KryoSerializer(conf) + transferService: Option[BlockTransferService] = Option.empty, + testConf: Option[SparkConf] = None): BlockManager = { + val bmConf = testConf.map(_.setAll(conf.getAll)).getOrElse(conf) + bmConf.set("spark.testing.memory", maxMem.toString) + bmConf.set("spark.memory.offHeap.size", maxMem.toString) + val serializer = new KryoSerializer(bmConf) + val 
encryptionKey = if (bmConf.get(IO_ENCRYPTION_ENABLED)) { + Some(CryptoStreamUtils.createKey(bmConf)) + } else { + None + } + val bmSecurityMgr = new SecurityManager(bmConf, encryptionKey) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) - val memManager = UnifiedMemoryManager(conf, numCores = 1) - val serializerManager = new SerializerManager(serializer, conf) - val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + val memManager = UnifiedMemoryManager(bmConf, numCores = 1) + val serializerManager = new SerializerManager(serializer, bmConf) + val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf, + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManager @@ -394,7 +405,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - val reregister = !master.driverEndpoint.askWithRetry[Boolean]( + val reregister = !master.driverEndpoint.askSync[Boolean]( BlockManagerHeartbeat(store.blockManagerId)) assert(reregister == true) } @@ -485,8 +496,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } - test("optimize a location order of blocks") { - val localHost = Utils.localHostName() + test("optimize a location order of blocks without topology information") { + val localHost = "localhost" val otherHost = "otherHost" val bmMaster = mock(classOf[BlockManagerMaster]) val bmId1 = BlockManagerId("id1", localHost, 1) @@ -497,7 +508,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) - assert(locations.map(_.host).toSet === Set(localHost, localHost, otherHost)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + + test("optimize a location order of blocks with topology information") { + val localHost = "localhost" + val otherHost = "otherHost" + val localRack = "localRack" + val otherRack = "otherRack" + + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1, Some(localRack)) + val bmId2 = BlockManagerId("id2", localHost, 2, Some(localRack)) + val bmId3 = BlockManagerId("id3", otherHost, 3, Some(otherRack)) + val bmId4 = BlockManagerId("id4", otherHost, 4, Some(otherRack)) + val bmId5 = BlockManagerId("id5", otherHost, 5, Some(localRack)) + when(bmMaster.getLocations(mc.any[BlockId])) + .thenReturn(Seq(bmId1, bmId2, bmId5, bmId3, bmId4)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + blockManager.blockManagerId = + BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) + assert(locations.flatMap(_.topologyInfo) + === Seq(localRack, localRack, 
localRack, otherRack, otherRack)) } test("SPARK-9591: getRemoteBytes from another location when Exception throw") { @@ -610,8 +646,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } - test("on-disk storage") { - store = makeBlockManager(1200) + encryptionTest("on-disk storage") { _conf => + store = makeBlockManager(1200, testConf = Some(_conf)) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -623,34 +659,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a1").isDefined, "a1 was in store") } - test("disk and memory storage") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false) + encryptionTest("disk and memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = false, testConf = _conf) } - test("disk and memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true) + encryptionTest("disk and memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, getAsBytes = true, testConf = _conf) } - test("disk and memory storage with serialization") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false) + encryptionTest("disk and memory storage with serialization") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = false, testConf = _conf) } - test("disk and memory storage with serialization and getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) + encryptionTest("disk and memory storage with serialization and getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true, testConf = _conf) } - test("disk and off-heap memory storage") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + encryptionTest("disk and off-heap memory storage") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false, testConf = _conf) } - test("disk and off-heap memory storage with getLocalBytes") { - testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + encryptionTest("disk and off-heap memory storage with getLocalBytes") { _conf => + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true, testConf = _conf) } def testDiskAndMemoryStorage( storageLevel: StorageLevel, - getAsBytes: Boolean): Unit = { - store = makeBlockManager(12000) + getAsBytes: Boolean, + testConf: SparkConf): Unit = { + store = makeBlockManager(12000, testConf = Some(testConf)) val accessMethod = if (getAsBytes) store.getLocalBytesAndReleaseLock else store.getSingleAndReleaseLock val a1 = new Array[Byte](4000) @@ -678,8 +715,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } - test("LRU with mixed storage levels") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) @@ -700,8 +737,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getSingleAndReleaseLock("a4").isDefined, "a4 was not in store") } - test("in-memory LRU with streams") { - store = makeBlockManager(12000) +
encryptionTest("in-memory LRU with streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -728,8 +765,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(store.getAndReleaseLock("list3") === None, "list1 was in store") } - test("LRU with mixed storage levels and streams") { - store = makeBlockManager(12000) + encryptionTest("LRU with mixed storage levels and streams") { _conf => + store = makeBlockManager(12000, testConf = Some(_conf)) val list1 = List(new Array[Byte](2000), new Array[Byte](2000)) val list2 = List(new Array[Byte](2000), new Array[Byte](2000)) val list3 = List(new Array[Byte](2000), new Array[Byte](2000)) @@ -1325,7 +1362,8 @@ private object BlockManagerSuite { val getAndReleaseLock: (BlockId) => Option[BlockResult] = wrapGet(store.get) val getSingleAndReleaseLock: (BlockId) => Option[Any] = wrapGet(store.getSingle) val getLocalBytesAndReleaseLock: (BlockId) => Option[ChunkedByteBuffer] = { - wrapGet(store.getLocalBytes) + val allocator = ByteBuffer.allocate _ + wrapGet { bid => store.getLocalBytes(bid).map(_.toChunkedByteBuffer(allocator)) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index 800c3899f1a72..4000218e71a8b 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,34 +18,35 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions +import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} import org.apache.spark.{LocalSparkContext, SparkFunSuite} -class BlockReplicationPolicySuite extends SparkFunSuite +class RandomBlockReplicationPolicyBehavior extends SparkFunSuite with Matchers with BeforeAndAfter with LocalSparkContext { // Implicitly convert strings to BlockIds for test clarity. 
- private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + protected implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) + val replicationPolicy: BlockReplicationPolicy = new RandomBlockReplicationPolicy + + val blockId = "test-block" /** * Test if we get the required number of peers when using random sampling from - * RandomBlockReplicationPolicy + * BlockReplicationPolicy */ - test(s"block replication - random block replication policy") { + test("block replication - random block replication policy") { val numBlockManagers = 10 val storeSize = 1000 - val blockManagers = (1 to numBlockManagers).map { i => - BlockManagerId(s"store-$i", "localhost", 1000 + i, None) - } + val blockManagers = generateBlockManagerIds(numBlockManagers, Seq("/Rack-1")) val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None) - val replicationPolicy = new RandomBlockReplicationPolicy - val blockId = "test-block" - (1 to 10).foreach {numReplicas => + (1 to 10).foreach { numReplicas => logDebug(s"Num replicas : $numReplicas") val randomPeers = replicationPolicy.prioritize( candidateBlockManager, @@ -68,7 +69,69 @@ class BlockReplicationPolicySuite extends SparkFunSuite logDebug(s"Random peers : ${secondPass.mkString(", ")}") assert(secondPass.toSet.size === numReplicas) } + } + + /** + * Returns a sequence of [[BlockManagerId]], whose rack is randomly picked from the given `racks`. + * Note that, each rack will be picked at least once from `racks`, if `count` is greater or equal + * to the number of `racks`. + */ + protected def generateBlockManagerIds(count: Int, racks: Seq[String]): Seq[BlockManagerId] = { + val randomizedRacks: Seq[String] = Random.shuffle( + racks ++ racks.length.until(count).map(_ => racks(Random.nextInt(racks.length))) + ) + (0 until count).map { i => + BlockManagerId(s"Exec-$i", s"Host-$i", 10000 + i, Some(randomizedRacks(i))) + } } +} +class TopologyAwareBlockReplicationPolicyBehavior extends RandomBlockReplicationPolicyBehavior { + override val replicationPolicy = new BasicBlockReplicationPolicy + + test("All peers in the same rack") { + val racks = Seq("/default-rack") + val numBlockManager = 10 + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(numBlockManager, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 10001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + assert(prioritizedPeers.forall(p => p.host != blockManager.host)) + } + } + + test("Peers in 2 racks") { + val racks = Seq("/Rack-1", "/Rack-2") + (1 to 10).foreach {numReplicas => + val peers = generateBlockManagerIds(10, racks) + val blockManager = BlockManagerId("Driver", "Host-driver", 9001, Some(racks.head)) + + val prioritizedPeers = replicationPolicy.prioritize( + blockManager, + peers, + mutable.HashSet.empty, + blockId, + numReplicas + ) + + assert(prioritizedPeers.toSet.size == numReplicas) + val priorityPeers = prioritizedPeers.take(2) + assert(priorityPeers.forall(p => p.host != blockManager.host)) + if(numReplicas > 1) { + // both these conditions should be satisfied when numReplicas > 1 + assert(priorityPeers.exists(p => p.topologyInfo == blockManager.topologyInfo)) + assert(priorityPeers.exists(p => p.topologyInfo != blockManager.topologyInfo)) + } + } + } } diff --git 
a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index bbfd6df3b6990..7859b0bba2b48 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import scala.language.reflectiveCalls - import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 684e978d11864..bfb3ac4c15bca 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.util.Utils class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("verify write metrics") { + private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = { val file = new File(tempDir, "somefile") + val conf = new SparkConf() + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true, + writeMetrics) + (writer, file, writeMetrics) + } + + test("verify write metrics") { + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("verify write metrics on revert") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("Reopening a closed block writer") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() writer.open() writer.close() @@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new 
SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() after commit() should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("revertPartialWritesAndClose() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() without ever opening or writing") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() val segment = writer.commitAndGet() writer.close() assert(segment.length === 0) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala index 9e6b02b9eac4d..67fc084e8a13d 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskStoreSuite.scala @@ -18,15 +18,23 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} -import java.util.Arrays +import java.util.{Arrays, Random} -import org.apache.spark.{SparkConf, SparkFunSuite} +import com.google.common.io.{ByteStreams, Files} +import io.netty.channel.FileRegion + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.util.{ByteArrayWritableChannel, JavaUtils} +import 
org.apache.spark.security.CryptoStreamUtils import org.apache.spark.util.io.ChunkedByteBuffer import org.apache.spark.util.Utils class DiskStoreSuite extends SparkFunSuite { test("reads of memory-mapped and non memory-mapped files are equivalent") { + val conf = new SparkConf() + val securityManager = new SecurityManager(conf) + // It will cause an error when we try to re-open the file store and the // memory-mapped byte buffer to the file has not been GC'd on Windows. assume(!Utils.isWindows) @@ -37,16 +45,18 @@ class DiskStoreSuite extends SparkFunSuite { val byteBuffer = new ChunkedByteBuffer(ByteBuffer.wrap(bytes)) val blockId = BlockId("rdd_1_2") - val diskBlockManager = new DiskBlockManager(new SparkConf(), deleteFilesOnStop = true) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) - val diskStoreMapped = new DiskStore(new SparkConf().set(confKey, "0"), diskBlockManager) + val diskStoreMapped = new DiskStore(conf.clone().set(confKey, "0"), diskBlockManager, + securityManager) diskStoreMapped.putBytes(blockId, byteBuffer) - val mapped = diskStoreMapped.getBytes(blockId) + val mapped = diskStoreMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer assert(diskStoreMapped.remove(blockId)) - val diskStoreNotMapped = new DiskStore(new SparkConf().set(confKey, "1m"), diskBlockManager) + val diskStoreNotMapped = new DiskStore(conf.clone().set(confKey, "1m"), diskBlockManager, + securityManager) diskStoreNotMapped.putBytes(blockId, byteBuffer) - val notMapped = diskStoreNotMapped.getBytes(blockId) + val notMapped = diskStoreNotMapped.getBytes(blockId).asInstanceOf[ByteBufferBlockData].buffer // Not possible to do isInstanceOf due to visibility of HeapByteBuffer assert(notMapped.getChunks().forall(_.getClass.getName.endsWith("HeapByteBuffer")), @@ -63,4 +73,95 @@ class DiskStoreSuite extends SparkFunSuite { assert(Arrays.equals(mapped.toArray, bytes)) assert(Arrays.equals(notMapped.toArray, bytes)) } + + test("block size tracking") { + val conf = new SparkConf() + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, new SecurityManager(conf)) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(new Array[Byte](32)) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === 32L) + diskStore.remove(blockId) + assert(diskStore.getSize(blockId) === 0L) + } + + test("block data encryption") { + val testDir = Utils.createTempDir() + val testData = new Array[Byte](128 * 1024) + new Random().nextBytes(testData) + + val conf = new SparkConf() + val securityManager = new SecurityManager(conf, Some(CryptoStreamUtils.createKey(conf))) + val diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) + val diskStore = new DiskStore(conf, diskBlockManager, securityManager) + + val blockId = BlockId("rdd_1_2") + diskStore.put(blockId) { chan => + val buf = ByteBuffer.wrap(testData) + while (buf.hasRemaining()) { + chan.write(buf) + } + } + + assert(diskStore.getSize(blockId) === testData.length) + + val diskData = Files.toByteArray(diskBlockManager.getFile(blockId.name)) + assert(!Arrays.equals(testData, diskData)) + + val blockData = diskStore.getBytes(blockId) + assert(blockData.isInstanceOf[EncryptedBlockData]) + assert(blockData.size === testData.length) + Map( + "input stream" -> readViaInputStream _, + "chunked byte buffer" -> readViaChunkedByteBuffer _, + "nio byte buffer" ->
readViaNioBuffer _, + "managed buffer" -> readViaManagedBuffer _ + ).foreach { case (name, fn) => + val readData = fn(blockData) + assert(readData.length === blockData.size, s"Size of data read via $name did not match.") + assert(Arrays.equals(testData, readData), s"Data read via $name did not match.") + } + } + + private def readViaInputStream(data: BlockData): Array[Byte] = { + val is = data.toInputStream() + try { + ByteStreams.toByteArray(is) + } finally { + is.close() + } + } + + private def readViaChunkedByteBuffer(data: BlockData): Array[Byte] = { + val buf = data.toChunkedByteBuffer(ByteBuffer.allocate _) + try { + buf.toArray + } finally { + buf.dispose() + } + } + + private def readViaNioBuffer(data: BlockData): Array[Byte] = { + JavaUtils.bufferToArray(data.toByteBuffer()) + } + + private def readViaManagedBuffer(data: BlockData): Array[Byte] = { + val region = data.toNetty().asInstanceOf[FileRegion] + val byteChannel = new ByteArrayWritableChannel(data.size.toInt) + + while (region.transfered() < region.count()) { + region.transferTo(byteChannel, region.transfered()) + } + + byteChannel.close() + byteChannel.getData + } + } diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd2..f7b3a2754f0ea 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,9 +33,13 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) @@ -43,7 +47,7 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) @@ -51,4 +55,17 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { assert(new File(Utils.getLocalDir(conf)).exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. 
+ assert(message.contains(s"$path1,$path2")) + } } diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala index ec4f2637fadd0..535105379963a 100644 --- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala @@ -67,7 +67,8 @@ class PartiallySerializedBlockSuite spy } - val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val serializer = serializerManager + .getSerializer(implicitly[ClassTag[T]], autoPick = true).newInstance() val redirectableOutputStream = Mockito.spy(new RedirectableOutputStream) redirectableOutputStream.setOutputStream(bbos) val serializationStream = Mockito.spy(serializer.serializeStream(redirectableOutputStream)) @@ -144,7 +145,7 @@ class PartiallySerializedBlockSuite try { TaskContext.setTaskContext(TaskContext.empty()) val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2) - TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted() + TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose() Mockito.verifyNoMoreInteractions(memoryStore) } finally { @@ -182,7 +183,8 @@ class PartiallySerializedBlockSuite Mockito.verifyNoMoreInteractions(memoryStore) Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer, atLeastOnce).dispose() - val serializer = serializerManager.getSerializer(implicitly[ClassTag[T]]).newInstance() + val serializer = serializerManager + .getSerializer(implicitly[ClassTag[T]], autoPick = true).newInstance() val deserialized = serializer.deserializeStream(new ByteBufferInputStream(bbos.toByteBuffer)).asIterator.toSeq assert(deserialized === items) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index e3ec99685f73c..9900d1edc4cb0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.InputStream +import java.io.{File, InputStream, IOException} import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global @@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ -import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException @@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Create a mock managed buffer for testing def createMockManagedBuffer(): ManagedBuffer = { val mockManagedBuffer = mock(classOf[ManagedBuffer]) - when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + val in = mock(classOf[InputStream]) + when(in.read(any())).thenReturn(1) + when(in.read(any(), any(), any())).thenReturn(1) + when(mockManagedBuffer.createInputStream()).thenReturn(in) mockManagedBuffer } @@ -99,8 +103,10 
@@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -172,8 +178,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -184,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() - taskContext.markTaskCompleted() + taskContext.markTaskCompleted(None) verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() // The 3rd block should not be retained because the iterator is already in zombie state @@ -201,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() ) // Semaphore to coordinate event sequence in two different threads. @@ -235,8 +243,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager, blocksByAddress, + (_, in) => in, 48 * 1024 * 1024, - Int.MaxValue) + Int.MaxValue, + true) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -247,4 +257,148 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } intercept[FetchFailedException] { iterator.next() } } + + test("retry corrupt blocks") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() + ) + + // Semaphore to coordinate event sequence in two different threads. 
+ val sem = new Semaphore(0) + + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) + sem.release() + } + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 100), + 48 * 1024 * 1024, + Int.MaxValue, + true) + + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + sem.release() + } + } + }) + + // The next block is corrupt local block (the second one is corrupt and retried) + intercept[FetchFailedException] { iterator.next() } + + sem.acquire() + intercept[FetchFailedException] { iterator.next() } + } + + test("retry corrupt blocks (disabled)") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer() + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, corruptBuffer) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptBuffer) + sem.release() + } + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 100), + 48 * 1024 * 1024, + Int.MaxValue, + false) + + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 0)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 1, 0)) + val (id3, _) = iterator.next() + assert(id3 === ShuffleBlockId(0, 2, 0)) + } + } diff --git a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala index e5733aebf607c..da198f946fd64 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageSuite.scala @@ -27,7 +27,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, and remove (for non-RDD blocks) private def storageStatus1: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.blocks.isEmpty) assert(status.rddBlocks.isEmpty) assert(status.memUsed === 0L) @@ -74,7 +74,7 @@ class StorageSuite extends SparkFunSuite { // For testing add, update, remove, get, and contains etc. for both RDD and non-RDD blocks private def storageStatus2: StorageStatus = { - val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) + val status = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) assert(status.rddBlocks.isEmpty) status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L)) status.addBlock(TestBlockId("man"), BlockStatus(memAndDisk, 10L, 20L)) @@ -252,9 +252,9 @@ class StorageSuite extends SparkFunSuite { // For testing StorageUtils.updateRddInfo and StorageUtils.getRddBlockLocations private def stockStorageStatuses: Seq[StorageStatus] = { - val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L) - val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L) - val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L) + val status1 = new StorageStatus(BlockManagerId("big", "dog", 1), 1000L, Some(1000L), Some(0L)) + val status2 = new StorageStatus(BlockManagerId("fat", "duck", 2), 2000L, Some(2000L), Some(0L)) + val status3 = new StorageStatus(BlockManagerId("fat", "cat", 3), 3000L, Some(3000L), Some(0L)) status1.addBlock(RDDBlockId(0, 0), BlockStatus(memAndDisk, 1L, 2L)) status1.addBlock(RDDBlockId(0, 1), BlockStatus(memAndDisk, 1L, 2L)) status2.addBlock(RDDBlockId(0, 2), BlockStatus(memAndDisk, 1L, 2L)) @@ -332,4 +332,81 @@ class StorageSuite extends SparkFunSuite { assert(blockLocations1(RDDBlockId(1, 2)) === Seq("cat:3")) } + private val offheap = StorageLevel.OFF_HEAP + // For testing add, update, remove, get, and contains etc. 
for both RDD and non-RDD onheap
+  // and offheap blocks
+  private def storageStatus3: StorageStatus = {
+    val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, Some(1000L), Some(1000L))
+    assert(status.rddBlocks.isEmpty)
+    status.addBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 10L, 20L))
+    status.addBlock(TestBlockId("man"), BlockStatus(offheap, 10L, 0L))
+    status.addBlock(RDDBlockId(0, 0), BlockStatus(offheap, 10L, 0L))
+    status.addBlock(RDDBlockId(1, 1), BlockStatus(offheap, 100L, 0L))
+    status.addBlock(RDDBlockId(2, 2), BlockStatus(memAndDisk, 10L, 20L))
+    status.addBlock(RDDBlockId(2, 3), BlockStatus(memAndDisk, 10L, 20L))
+    status.addBlock(RDDBlockId(2, 4), BlockStatus(memAndDisk, 10L, 40L))
+    status
+  }
+
+  test("storage memUsed, diskUsed with on-heap and off-heap blocks") {
+    val status = storageStatus3
+    def actualMemUsed: Long = status.blocks.values.map(_.memSize).sum
+    def actualDiskUsed: Long = status.blocks.values.map(_.diskSize).sum
+
+    def actualOnHeapMemUsed: Long =
+      status.blocks.values.filter(!_.storageLevel.useOffHeap).map(_.memSize).sum
+    def actualOffHeapMemUsed: Long =
+      status.blocks.values.filter(_.storageLevel.useOffHeap).map(_.memSize).sum
+
+    assert(status.maxMem === status.maxOnHeapMem.get + status.maxOffHeapMem.get)
+
+    assert(status.memUsed === actualMemUsed)
+    assert(status.diskUsed === actualDiskUsed)
+    assert(status.onHeapMemUsed.get === actualOnHeapMemUsed)
+    assert(status.offHeapMemUsed.get === actualOffHeapMemUsed)
+
+    assert(status.memRemaining === status.maxMem - actualMemUsed)
+    assert(status.onHeapMemRemaining.get === status.maxOnHeapMem.get - actualOnHeapMemUsed)
+    assert(status.offHeapMemRemaining.get === status.maxOffHeapMem.get - actualOffHeapMemUsed)
+
+    status.addBlock(TestBlockId("wire"), BlockStatus(memAndDisk, 400L, 500L))
+    status.addBlock(RDDBlockId(25, 25), BlockStatus(memAndDisk, 40L, 50L))
+    assert(status.memUsed === actualMemUsed)
+    assert(status.diskUsed === actualDiskUsed)
+
+    status.updateBlock(TestBlockId("dan"), BlockStatus(memAndDisk, 4L, 5L))
+    status.updateBlock(RDDBlockId(0, 0), BlockStatus(offheap, 4L, 0L))
+    status.updateBlock(RDDBlockId(1, 1), BlockStatus(offheap, 4L, 0L))
+    assert(status.memUsed === actualMemUsed)
+    assert(status.diskUsed === actualDiskUsed)
+    assert(status.onHeapMemUsed.get === actualOnHeapMemUsed)
+    assert(status.offHeapMemUsed.get === actualOffHeapMemUsed)
+
+    status.removeBlock(TestBlockId("fire"))
+    status.removeBlock(TestBlockId("man"))
+    status.removeBlock(RDDBlockId(2, 2))
+    status.removeBlock(RDDBlockId(2, 3))
+    assert(status.memUsed === actualMemUsed)
+    assert(status.diskUsed === actualDiskUsed)
+  }
+
+  private def storageStatus4: StorageStatus = {
+    val status = new StorageStatus(BlockManagerId("big", "dog", 1), 2000L, None, None)
+    status
+  }
+  test("old SparkListenerBlockManagerAdded event compatible") {
+    // This scenario will only happen when replaying an old event log. In this scenario there's
+    // no block add or remove event replayed, so only the total amount of memory is valid.
+    val status = storageStatus4
+    assert(status.maxMem === status.maxMemory)
+
+    assert(status.memUsed === 0L)
+    assert(status.diskUsed === 0L)
+    assert(status.onHeapMemUsed === None)
+    assert(status.offHeapMemUsed === None)
+
+    assert(status.memRemaining === status.maxMem)
+    assert(status.onHeapMemRemaining === None)
+    assert(status.offHeapMemRemaining === None)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
index d30b987d6ca31..499d47b13d702 100644
--- a/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/StagePageSuite.scala
@@ -17,6 +17,7 @@ package org.apache.spark.ui
+import java.util.Locale
 import javax.servlet.http.HttpServletRequest
 import scala.xml.Node
@@ -35,26 +36,16 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
   private val peakExecutionMemory = 10
-  test("peak execution memory only displayed if unsafe is enabled") {
-    val unsafeConf = "spark.sql.unsafe.enabled"
-    val conf = new SparkConf(false).set(unsafeConf, "true")
-    val html = renderStagePage(conf).toString().toLowerCase
+  test("peak execution memory should be displayed") {
+    val conf = new SparkConf(false)
+    val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT)
     val targetString = "peak execution memory"
     assert(html.contains(targetString))
-    // Disable unsafe and make sure it's not there
-    val conf2 = new SparkConf(false).set(unsafeConf, "false")
-    val html2 = renderStagePage(conf2).toString().toLowerCase
-    assert(!html2.contains(targetString))
-    // Avoid setting anything; it should be displayed by default
-    val conf3 = new SparkConf(false)
-    val html3 = renderStagePage(conf3).toString().toLowerCase
-    assert(html3.contains(targetString))
   }
 
   test("SPARK-10543: peak execution memory should be per-task rather than cumulative") {
-    val unsafeConf = "spark.sql.unsafe.enabled"
-    val conf = new SparkConf(false).set(unsafeConf, "true")
-    val html = renderStagePage(conf).toString().toLowerCase
+    val conf = new SparkConf(false)
+    val html = renderStagePage(conf).toString().toLowerCase(Locale.ROOT)
     // verify min/25/50/75/max show task value not cumulative values
     assert(html.contains(s"$peakExecutionMemory.0 b" * 5))
   }
@@ -87,7 +78,7 @@ class StagePageSuite extends SparkFunSuite with LocalSparkContext {
       val taskInfo = new TaskInfo(taskId, taskId, 0, 0, "0", "localhost", TaskLocality.ANY, false)
       jobListener.onStageSubmitted(SparkListenerStageSubmitted(stageInfo))
       jobListener.onTaskStart(SparkListenerTaskStart(0, 0, taskInfo))
-      taskInfo.markFinished(TaskState.FINISHED)
+      taskInfo.markFinished(TaskState.FINISHED, System.currentTimeMillis())
       val taskMetrics = TaskMetrics.empty
       taskMetrics.incPeakExecutionMemory(peakExecutionMemory)
       jobListener.onTaskEnd(
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
index e5d408a167361..bdd148875e38a 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.ui
 import java.net.{HttpURLConnection, URL}
+import java.util.Locale
 import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
 import scala.io.Source
@@ -39,7 +40,7 @@ import org.apache.spark.LocalSparkContext._
 import org.apache.spark.api.java.StorageLevels
 import org.apache.spark.deploy.history.HistoryServerSuite
 import
org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.status.api.v1.{JacksonMessageWriter, StageStatus} +import org.apache.spark.status.api.v1.{JacksonMessageWriter, RDDDataDistribution, StageStatus} private[spark] class SparkUICssErrorHandler extends DefaultCssErrorHandler { @@ -103,6 +104,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B .set("spark.ui.enabled", "true") .set("spark.ui.port", "0") .set("spark.ui.killEnabled", killEnabled.toString) + .set("spark.memory.offHeap.size", "64m") val sc = new SparkContext(conf) assert(sc.ui.isDefined) sc @@ -151,6 +153,39 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val updatedRddJson = getJson(ui, "storage/rdd/0") (updatedRddJson \ "storageLevel").extract[String] should be ( StorageLevels.MEMORY_ONLY.description) + + val dataDistributions0 = + (updatedRddJson \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions0.length should be (1) + val dist0 = dataDistributions0.head + + dist0.onHeapMemoryUsed should not be (None) + dist0.memoryUsed should be (dist0.onHeapMemoryUsed.get) + dist0.onHeapMemoryRemaining should not be (None) + dist0.offHeapMemoryRemaining should not be (None) + dist0.memoryRemaining should be ( + dist0.onHeapMemoryRemaining.get + dist0.offHeapMemoryRemaining.get) + dist0.onHeapMemoryUsed should not be (Some(0L)) + dist0.offHeapMemoryUsed should be (Some(0L)) + + rdd.unpersist() + rdd.persist(StorageLevels.OFF_HEAP).count() + val updatedStorageJson1 = getJson(ui, "storage/rdd") + updatedStorageJson1.children.length should be (1) + val updatedRddJson1 = getJson(ui, "storage/rdd/0") + val dataDistributions1 = + (updatedRddJson1 \ "dataDistribution").extract[Seq[RDDDataDistribution]] + dataDistributions1.length should be (1) + val dist1 = dataDistributions1.head + + dist1.offHeapMemoryUsed should not be (None) + dist1.memoryUsed should be (dist1.offHeapMemoryUsed.get) + dist1.onHeapMemoryRemaining should not be (None) + dist1.offHeapMemoryRemaining should not be (None) + dist1.memoryRemaining should be ( + dist1.onHeapMemoryRemaining.get + dist1.offHeapMemoryRemaining.get) + dist1.onHeapMemoryUsed should be (Some(0L)) + dist1.offHeapMemoryUsed should not be (Some(0L)) } } @@ -419,8 +454,8 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B eventually(timeout(10 seconds), interval(50 milliseconds)) { goToUi(sc, "/jobs") findAll(cssSelector("tbody tr a")).foreach { link => - link.text.toLowerCase should include ("count") - link.text.toLowerCase should not include "unknown" + link.text.toLowerCase(Locale.ROOT) should include ("count") + link.text.toLowerCase(Locale.ROOT) should not include "unknown" } } } @@ -473,10 +508,10 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/stages/stage/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/stages/stage/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST - getResponseCode(url, "GET") should be (200) - getResponseCode(url, "POST") should be (200) + TestUtils.httpResponseCode(url, "GET") should be (200) + TestUtils.httpResponseCode(url, "POST") should be (200) } } } @@ -486,10 +521,10 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(1 to 
10).map{x => Thread.sleep(10000); x}.countAsync() eventually(timeout(5 seconds), interval(50 milliseconds)) { val url = new URL( - sc.ui.get.appUIAddress.stripSuffix("/") + "/jobs/job/kill/?id=0") + sc.ui.get.webUrl.stripSuffix("/") + "/jobs/job/kill/?id=0") // SPARK-6846: should be POST only but YARN AM doesn't proxy POST - getResponseCode(url, "GET") should be (200) - getResponseCode(url, "POST") should be (200) + TestUtils.httpResponseCode(url, "GET") should be (200) + TestUtils.httpResponseCode(url, "POST") should be (200) } } } @@ -620,7 +655,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B test("live UI json application list") { withSpark(newSparkContext()) { sc => val appListRawJson = HistoryServerSuite.getUrl(new URL( - sc.ui.get.appUIAddress + "/api/v1/applications")) + sc.ui.get.webUrl + "/api/v1/applications")) val appListJsonAst = JsonMethods.parse(appListRawJson) appListJsonAst.children.length should be (1) val attempts = (appListJsonAst \ "attempts").children @@ -640,7 +675,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B sc.parallelize(Seq(1, 2, 3)).map(identity).groupBy(identity).map(identity).groupBy(identity) rdd.count() - val stage0 = Source.fromURL(sc.ui.get.appUIAddress + + val stage0 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=0&attempt=0&expandDagViz=true").mkString assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) @@ -651,7 +686,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("{\n label="groupBy";\n " + "2 [label="MapPartitionsRDD [2]")) - val stage1 = Source.fromURL(sc.ui.get.appUIAddress + + val stage1 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) @@ -662,7 +697,7 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage1.contains("{\n label="groupBy";\n " + "5 [label="MapPartitionsRDD [5]")) - val stage2 = Source.fromURL(sc.ui.get.appUIAddress + + val stage2 = Source.fromURL(sc.ui.get.webUrl + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) @@ -671,23 +706,12 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } } - def getResponseCode(url: URL, method: String): Int = { - val connection = url.openConnection().asInstanceOf[HttpURLConnection] - connection.setRequestMethod(method) - try { - connection.connect() - connection.getResponseCode() - } finally { - connection.disconnect() - } - } - def goToUi(sc: SparkContext, path: String): Unit = { goToUi(sc.ui.get, path) } def goToUi(ui: SparkUI, path: String): Unit = { - go to (ui.appUIAddress.stripSuffix("/") + path) + go to (ui.webUrl.stripSuffix("/") + path) } def parseDate(json: JValue): Long = { @@ -699,6 +723,6 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B } def apiUrl(ui: SparkUI, path: String): URL = { - new URL(ui.appUIAddress + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) + new URL(ui.webUrl + "/api/v1/applications/" + ui.sc.get.applicationId + "/" + path) } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 4abcfb7e51914..0c3d4caeeabf9 100644 --- 
a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -18,18 +18,20 @@ package org.apache.spark.ui import java.net.{BindException, ServerSocket} -import java.net.URI -import javax.servlet.http.HttpServletRequest +import java.net.{URI, URL} +import java.util.Locale +import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source -import org.eclipse.jetty.servlet.ServletContextHandler +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.mockito.Mockito.{mock, when} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.LocalSparkContext._ +import org.apache.spark.util.Utils class UISuite extends SparkFunSuite { @@ -52,13 +54,16 @@ class UISuite extends SparkFunSuite { (conf, new SecurityManager(conf).getSSLOptions("ui")) } - private def sslEnabledConf(): (SparkConf, SSLOptions) = { + private def sslEnabledConf(sslPort: Option[Int] = None): (SparkConf, SSLOptions) = { val keyStoreFilePath = getTestResourcePath("spark.keystore") val conf = new SparkConf() .set("spark.ssl.ui.enabled", "true") .set("spark.ssl.ui.keyStore", keyStoreFilePath) .set("spark.ssl.ui.keyStorePassword", "123456") .set("spark.ssl.ui.keyPassword", "123456") + sslPort.foreach { p => + conf.set("spark.ssl.ui.port", p.toString) + } (conf, new SecurityManager(conf).getSSLOptions("ui")) } @@ -66,12 +71,12 @@ class UISuite extends SparkFunSuite { withSpark(newSparkContext()) { sc => // test if the ui is visible, and all the expected tabs are visible eventually(timeout(10 seconds), interval(50 milliseconds)) { - val html = Source.fromURL(sc.ui.get.appUIAddress).mkString + val html = Source.fromURL(sc.ui.get.webUrl).mkString assert(!html.contains("random data that should not be present")) - assert(html.toLowerCase.contains("stages")) - assert(html.toLowerCase.contains("storage")) - assert(html.toLowerCase.contains("environment")) - assert(html.toLowerCase.contains("executors")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("storage")) + assert(html.toLowerCase(Locale.ROOT).contains("environment")) + assert(html.toLowerCase(Locale.ROOT).contains("executors")) } } } @@ -81,7 +86,7 @@ class UISuite extends SparkFunSuite { // test if visible from http://localhost:4040 eventually(timeout(10 seconds), interval(50 milliseconds)) { val html = Source.fromURL("http://localhost:4040").mkString - assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase(Locale.ROOT).contains("stages")) } } } @@ -167,6 +172,7 @@ class UISuite extends SparkFunSuite { val boundPort = serverInfo.boundPort assert(server.getState === "STARTED") assert(boundPort != 0) + assert(serverInfo.securePort.isDefined) intercept[BindException] { socket = new ServerSocket(boundPort) } @@ -176,19 +182,18 @@ class UISuite extends SparkFunSuite { } } - test("verify appUIAddress contains the scheme") { + test("verify webUrl contains the scheme") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val uiAddress = ui.appUIAddress - val uiHostPort = ui.appUIHostPort - assert(uiAddress.equals("http://" + uiHostPort)) + val uiAddress = ui.webUrl + assert(uiAddress.startsWith("http://") || uiAddress.startsWith("https://")) } } - test("verify appUIAddress contains the port") { + test("verify webUrl contains the port") { withSpark(newSparkContext()) { sc => val ui = sc.ui.get - val splitUIAddress = 
ui.appUIAddress.split(':') + val splitUIAddress = ui.webUrl.split(':') val boundPort = ui.boundPort assert(splitUIAddress(2).toInt == boundPort) } @@ -228,8 +233,77 @@ class UISuite extends SparkFunSuite { assert(newHeader === null) } + test("http -> https redirect applies to all URIs") { + var serverInfo: ServerInfo = null + try { + val servlet = new HttpServlet() { + override def doGet(req: HttpServletRequest, res: HttpServletResponse): Unit = { + res.sendError(HttpServletResponse.SC_OK) + } + } + + def newContext(path: String): ServletContextHandler = { + val ctx = new ServletContextHandler() + ctx.setContextPath(path) + ctx.addServlet(new ServletHolder(servlet), "/root") + ctx + } + + val (conf, sslOptions) = sslEnabledConf() + serverInfo = JettyUtils.startJettyServer("0.0.0.0", 0, sslOptions, + Seq[ServletContextHandler](newContext("/"), newContext("/test1")), + conf) + assert(serverInfo.server.getState === "STARTED") + + val testContext = newContext("/test2") + serverInfo.addHandler(testContext) + testContext.start() + + val httpPort = serverInfo.boundPort + + val tests = Seq( + ("http", serverInfo.boundPort, HttpServletResponse.SC_FOUND), + ("https", serverInfo.securePort.get, HttpServletResponse.SC_OK)) + + tests.foreach { case (scheme, port, expected) => + val urls = Seq( + s"$scheme://localhost:$port/root", + s"$scheme://localhost:$port/test1/root", + s"$scheme://localhost:$port/test2/root") + urls.foreach { url => + val rc = TestUtils.httpResponseCode(new URL(url)) + assert(rc === expected, s"Unexpected status $rc for $url") + } + } + } finally { + stopServer(serverInfo) + } + } + + test("specify both http and https ports separately") { + var socket: ServerSocket = null + var serverInfo: ServerInfo = null + try { + socket = new ServerSocket(0) + + // Make sure the SSL port lies way outside the "http + 400" range used as the default. + val baseSslPort = Utils.userPort(socket.getLocalPort(), 10000) + val (conf, sslOptions) = sslEnabledConf(sslPort = Some(baseSslPort)) + + serverInfo = JettyUtils.startJettyServer("0.0.0.0", socket.getLocalPort() + 1, + sslOptions, Seq[ServletContextHandler](), conf, "server1") + + val notAllowed = Utils.userPort(serverInfo.boundPort, 400) + assert(serverInfo.securePort.isDefined) + assert(serverInfo.securePort.get != Utils.userPort(serverInfo.boundPort, 400)) + } finally { + stopServer(serverInfo) + closeSocket(socket) + } + } + def stopServer(info: ServerInfo): Unit = { - if (info != null && info.server != null) info.server.stop + if (info != null) info.stop() } def closeSocket(socket: ServerSocket): Unit = { diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala index 6335d905c0fbf..423daacc0f5a5 100644 --- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala @@ -110,7 +110,7 @@ class UIUtilsSuite extends SparkFunSuite { } test("SPARK-11906: Progress bar should not overflow because of speculative tasks") { - val generated = makeProgressBar(2, 3, 0, 0, 0, 4).head.child.filter(_.label == "div") + val generated = makeProgressBar(2, 3, 0, 0, Map.empty, 4).head.child.filter(_.label == "div") val expected = Seq(
    ,
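// Illustrative sketch, not part of this patch: the SPARK-11906 hunk above calls
// makeProgressBar(2, 3, 0, 0, Map.empty, 4), a case where speculative tasks make the
// number of running plus completed tasks exceed the total of 4, and asserts that the
// rendered bar does not overflow. A minimal, hypothetical clamping helper (the name
// barWidths and its signature are assumptions, not Spark API) shows the property the
// test relies on: segment widths are capped so the bar never exceeds 100%.
def barWidths(completed: Int, started: Int, total: Int): (Double, Double) = {
  // Completed segment, capped at 100% of the bar.
  val completedPct = math.min(100.0, completed.toDouble / total * 100)
  // Running segment, capped at whatever width remains after the completed segment.
  val startedPct = math.min(100.0 - completedPct, started.toDouble / total * 100)
  (completedPct, startedPct)
}
// For example, barWidths(2, 3, 4) yields (50.0, 50.0) rather than (50.0, 75.0), so the
// two <div> segments together can never overflow the enclosing progress bar.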
@@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite {
     assert(decoded2 === decodeURLParameter(decoded2))
   }
+  test("SPARK-20393: Prevent newline characters in parameters.") {
+    val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+    val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+
+    assert(stripEncoding === stripXSS(encoding))
+  }
+
+  test("SPARK-20393: Prevent script from parameters running on page.") {
+    val scriptAlert = """>"'>